| """Benchmarks for the MaxSim kernel against naive PyTorch baselines. |
| |
| This is the file the ``kernels`` CLI discovers and runs:: |
| |
| kernels benchmark erikkaum/maxsim # against the published kernel |
| just bench-local # against the local ./build |
| |
| The workloads mirror ``scripts/cuda_bench_matrix.py`` (the README-number |
| generator) so the CLI tables and the README matrix describe the same shapes: |
| |
| * Contrastive training -- the headline ColBERT fine-tuning step (forward + |
| backward), at the in-batch shape an actual PyLate batch hits. |
| * Padded inference -- the second-stage exact rerank (forward only). |
| * Packed inference -- the same rerank shape through the ragged pair API, |
| to surface the packing/layout overhead. |
| |
| Naive baselines are identical to the matrix's: a ``torch.einsum`` that |
| materialises the full similarity tensor before ``max`` -- what the kernel is |
| meant to beat. fp16 inputs throughout; the matrix covers bf16 as well. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| from kernels.benchmark import Benchmark |
|
|
# Element dtype of every randomly generated query/document tensor.
_DTYPE: torch.dtype = torch.float16
# Embedding width: the last dimension of every generated tensor.
_DIM: int = 128
# Value assigned to the class-level ``seed`` of each benchmark class below.
_SEED: int = 1234
|
|
|
|
| |
| |
| |
|
|
def _naive_contrastive(q, docs, cu, d_lens):
    """All-pairs MaxSim over packed docs: materialises [Nq, Nb, Lq, Ld].

    Args:
        q: Query token embeddings, ``[Nq, Lq, dim]``.
        docs: Packed document token embeddings, ``[total_tokens, dim]``.
        cu: 1-D tensor of document offsets into ``docs``; entries ``i`` and
            ``i + 1`` bound the rows of document ``i``.
        d_lens: Per-document token counts (Python sequence, or anything
            ``torch.as_tensor`` accepts), one entry per document.

    Returns:
        float32 ``[Nq, Nb]`` tensor: for every (query, document) pair, the
        sum over query tokens of the best-matching document-token similarity.

    Raises:
        ValueError: if ``d_lens`` is empty (``max`` of an empty sequence).

    Documents shorter than the longest one are padded for the einsum. Padded
    positions are excluded from the max by filling them with ``-inf``;
    leaving them at zero similarity would win the max whenever every real
    similarity for a query token is negative. A zero-length document
    therefore scores ``-inf``, the same way ``_naive_padded`` treats a
    candidate whose length is 0.
    """
    _Nq, _Lq, dim = q.shape
    Nb = len(d_lens)
    Ld_max = max(d_lens)
    docs_padded = docs.new_zeros((Nb, Ld_max, dim))
    # One host round-trip for the offsets instead of one per document.
    offs = cu.to(torch.int64).cpu().tolist()
    for i, ld_i in enumerate(d_lens):
        docs_padded[i, :ld_i] = docs[offs[i] : offs[i + 1]]
    sim = torch.einsum("qid,njd->qnij", q.float(), docs_padded.float())
    # Only ragged batches contain padding; uniform-length batches skip the
    # mask so their execution is exactly the plain einsum + max.
    if min(d_lens) < Ld_max:
        lens = torch.as_tensor(d_lens, device=sim.device)
        is_pad = torch.arange(Ld_max, device=sim.device)[None, :] >= lens[:, None]
        sim = sim.masked_fill(is_pad[None, :, None, :], float("-inf"))
    return sim.max(dim=-1).values.sum(dim=-1)
|
|
|
|
def _naive_padded(q, d, qlen, dlen):
    """Reference padded MaxSim; builds the whole [B, C, Lq, Ld] similarity tensor.

    Args:
        q: Padded query embeddings, ``[B, Lq, dim]``.
        d: Padded candidate embeddings, ``[B, C, Ld, dim]``.
        qlen: ``[B]`` count of valid query tokens per query.
        dlen: ``[B, C]`` count of valid document tokens per candidate.

    Returns:
        float32 ``[B, C]`` scores: per candidate, the sum over valid query
        tokens of the maximum similarity against valid document tokens.
    """
    _, _, max_doc_tokens, _ = d.shape
    max_query_tokens = q.shape[1]

    similarity = torch.einsum("bid,bcjd->bcij", q.float(), d.float())

    query_positions = torch.arange(max_query_tokens, device=q.device)
    query_valid = query_positions.unsqueeze(0) < qlen.unsqueeze(1)
    doc_positions = torch.arange(max_doc_tokens, device=q.device)
    doc_valid = doc_positions.view(1, 1, -1) < dlen.unsqueeze(-1)

    # Padded document tokens must never win the max.
    similarity = similarity.masked_fill(~doc_valid.unsqueeze(2), float("-inf"))
    best_per_token = similarity.max(dim=-1).values
    # Padded query tokens contribute nothing to the sum.
    best_per_token = best_per_token.masked_fill(~query_valid.unsqueeze(1), 0.0)
    return best_per_token.sum(dim=-1)
|
|
|
|
| |
| |
| |
|
|
def _make_contrastive(self: Benchmark, Nq, Nb, Lq, Ld) -> None:
    """Populate ``self`` with a uniform-length in-batch contrastive workload.

    Sets ``queries`` ``[Nq, Lq, _DIM]``, packed ``documents``
    ``[Nb * Ld, _DIM]``, int32 ``document_offsets`` (``Nb + 1`` entries spaced
    ``Ld`` apart) and the Python list ``d_lens`` (``Nb`` copies of ``Ld``).
    Random data is drawn from a generator seeded with ``self.seed`` and then
    moved to ``self.device``.
    """
    rng = torch.Generator().manual_seed(self.seed)
    target = self.device

    # Queries are drawn before documents; the order fixes the random stream.
    query_data = torch.randn(Nq, Lq, _DIM, generator=rng, dtype=_DTYPE)
    self.queries = query_data.to(target)
    document_data = torch.randn(Nb * Ld, _DIM, generator=rng, dtype=_DTYPE)
    self.documents = document_data.to(target)

    total_tokens = Nb * Ld
    self.document_offsets = torch.arange(
        0, total_tokens + Ld, Ld, dtype=torch.int32, device=target
    )
    self.d_lens = [Ld for _ in range(Nb)]
|
|
|
|
def _make_padded(self: Benchmark, B, C, Lq, Ld) -> None:
    """Populate ``self`` with a fully padded rerank workload.

    Sets ``queries`` ``[B, Lq, _DIM]``, ``documents`` ``[B, C, Ld, _DIM]``,
    and int32 length tensors ``query_lengths`` ``[B]`` / ``doc_lengths``
    ``[B, C]`` filled with ``Lq`` / ``Ld`` (every sequence is full length).
    Random data is drawn from a generator seeded with ``self.seed`` and then
    moved to ``self.device``.
    """
    rng = torch.Generator().manual_seed(self.seed)
    target = self.device

    # Queries are drawn before documents; the order fixes the random stream.
    query_data = torch.randn(B, Lq, _DIM, generator=rng, dtype=_DTYPE)
    self.queries = query_data.to(target)
    document_data = torch.randn(B, C, Ld, _DIM, generator=rng, dtype=_DTYPE)
    self.documents = document_data.to(target)

    int_opts = {"dtype": torch.int32, "device": target}
    self.query_lengths = torch.full((B,), Lq, **int_opts)
    self.doc_lengths = torch.full((B, C), Ld, **int_opts)
|
|
|
|
def _make_packed(self: Benchmark, B, C, Lq, Ld) -> None:
    """Build the padded workload, then a flattened pair grid over the same data.

    On top of everything ``_make_padded`` sets, this stores:

    * ``q_flat`` ``[B * Lq, _DIM]`` and ``d_flat`` ``[B * C * Ld, _DIM]`` --
      contiguous flattened copies/views of the padded tensors;
    * ``q_offsets`` / ``d_offsets`` -- int32 boundaries spaced ``Lq`` / ``Ld``
      apart (``B + 1`` and ``B * C + 1`` entries);
    * ``pair_query_ids`` / ``pair_document_ids`` -- one int32 entry per
      (query, candidate) pair: document ``p`` is paired with query ``p // C``;
    * ``max_q_len``, ``batch``, ``candidates`` -- ``Lq``, ``B``, ``C``.
    """
    _make_padded(self, B, C, Lq, Ld)
    device = self.queries.device
    num_pairs = B * C

    self.max_q_len = Lq
    self.batch = B
    self.candidates = C

    self.q_flat = self.queries.reshape(B * Lq, _DIM).contiguous()
    self.d_flat = self.documents.reshape(num_pairs * Ld, _DIM).contiguous()

    self.q_offsets = torch.arange(
        0, (B + 1) * Lq, Lq, dtype=torch.int32, device=device
    )
    self.d_offsets = torch.arange(
        0, (num_pairs + 1) * Ld, Ld, dtype=torch.int32, device=device
    )

    document_ids = torch.arange(num_pairs, dtype=torch.int32, device=device)
    self.pair_document_ids = document_ids
    # Candidates are laid out query-major, so integer division recovers the query.
    self.pair_query_ids = document_ids // C
|
|
|
|
| |
| |
| |
|
|
def _contrastive_train_kernel(self: Benchmark) -> torch.Tensor:
    """Run one forward + backward step through ``score_contrastive_train``.

    Fresh grad-enabled clones of ``self.queries`` / ``self.documents`` are
    made on every call, so gradients never accumulate on the stored inputs.
    Returns the detached score tensor.
    """
    queries = self.queries.detach().clone()
    queries.requires_grad_(True)
    documents = self.documents.detach().clone()
    documents.requires_grad_(True)

    scores = self.kernel.score_contrastive_train(
        queries, documents, self.document_offsets
    )
    # Scalar objective whose only purpose is to drive the backward pass.
    total = scores.sum()
    total.backward()
    return scores.detach()
|
|
|
|
def _contrastive_train_naive(self: Benchmark) -> torch.Tensor:
    """Run one forward + backward step through the einsum baseline.

    Same structure as the kernel step: fresh grad-enabled clones of the
    stored inputs, ``_naive_contrastive`` forward, ``sum().backward()``.
    Returns the detached score tensor.
    """
    queries = self.queries.detach().clone()
    queries.requires_grad_(True)
    documents = self.documents.detach().clone()
    documents.requires_grad_(True)

    scores = _naive_contrastive(
        queries, documents, self.document_offsets, self.d_lens
    )
    # Scalar objective whose only purpose is to drive the backward pass.
    total = scores.sum()
    total.backward()
    return scores.detach()
|
|
|
|
def _contrastive_ref(self: Benchmark) -> torch.Tensor:
    """Forward-only baseline scores for the stored contrastive workload."""
    return _naive_contrastive(
        q=self.queries,
        docs=self.documents,
        cu=self.document_offsets,
        d_lens=self.d_lens,
    )
|
|
|
|
def _padded_kernel(self: Benchmark) -> torch.Tensor:
    """Score the stored padded workload with ``score_candidates_padded``."""
    scores = self.kernel.score_candidates_padded(
        self.queries,
        self.documents,
        self.query_lengths,
        self.doc_lengths,
    )
    return scores
|
|
|
|
def _padded_naive(self: Benchmark) -> torch.Tensor:
    """Score the stored padded workload with the einsum baseline."""
    return _naive_padded(
        q=self.queries,
        d=self.documents,
        qlen=self.query_lengths,
        dlen=self.doc_lengths,
    )
|
|
|
|
def _packed_kernel(self: Benchmark) -> torch.Tensor:
    """Score the stored pair grid with ``score_pairs_packed``.

    The kernel result is viewed as ``[batch, candidates]`` so it lines up
    with the ``[B, C]`` output of the padded baseline.
    """
    flat_scores = self.kernel.score_pairs_packed(
        self.q_flat,
        self.q_offsets,
        self.d_flat,
        self.d_offsets,
        self.pair_query_ids,
        self.pair_document_ids,
        max_q_len=self.max_q_len,
    )
    return flat_scores.view(self.batch, self.candidates)
|
|
|
|
| |
| |
| |
|
|
class ContrastiveLateOn(Benchmark):
    """In-batch contrastive training: Nq=Nb=32, Lq=32, Ld=80, dim=128.

    Times one forward + backward step through the kernel and through the
    einsum baseline on identical inputs.
    """

    seed = _SEED

    def setup(self) -> None:
        """Generate the packed contrastive inputs for this shape."""
        shape = {"Nq": 32, "Nb": 32, "Lq": 32, "Ld": 80}
        _make_contrastive(self, **shape)

    def benchmark_kernel(self) -> None:
        """Timed workload: kernel forward + backward; scores kept in ``self.out``."""
        scores = _contrastive_train_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Timed workload: baseline forward + backward; scores kept in ``self.out``."""
        scores = _contrastive_train_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Reference scores for the kernel run (presumably checked against ``self.out``)."""
        reference = _contrastive_ref(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Reference scores for the baseline run (presumably checked against ``self.out``)."""
        reference = _contrastive_ref(self)
        return reference
|
|
|
|
class ContrastiveLongDocs(Benchmark):
    """Same in-batch shape but long docs (Ld=512) -- stresses retained state.

    Nq=Nb=32 and Lq=32 as in the base contrastive shape; only the document
    length grows.
    """

    seed = _SEED

    def setup(self) -> None:
        """Generate the packed contrastive inputs for this shape."""
        shape = {"Nq": 32, "Nb": 32, "Lq": 32, "Ld": 512}
        _make_contrastive(self, **shape)

    def benchmark_kernel(self) -> None:
        """Timed workload: kernel forward + backward; scores kept in ``self.out``."""
        scores = _contrastive_train_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Timed workload: baseline forward + backward; scores kept in ``self.out``."""
        scores = _contrastive_train_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Reference scores for the kernel run (presumably checked against ``self.out``)."""
        reference = _contrastive_ref(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Reference scores for the baseline run (presumably checked against ``self.out``)."""
        reference = _contrastive_ref(self)
        return reference
|
|
|
|
class ContrastiveBigBatch(Benchmark):
    """Doubled in-batch batch size: Nq=Nb=64, Lq=32, Ld=128, dim=128.

    Times one forward + backward step through the kernel and through the
    einsum baseline on identical inputs.
    """

    seed = _SEED

    def setup(self) -> None:
        """Generate the packed contrastive inputs for this shape."""
        shape = {"Nq": 64, "Nb": 64, "Lq": 32, "Ld": 128}
        _make_contrastive(self, **shape)

    def benchmark_kernel(self) -> None:
        """Timed workload: kernel forward + backward; scores kept in ``self.out``."""
        scores = _contrastive_train_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Timed workload: baseline forward + backward; scores kept in ``self.out``."""
        scores = _contrastive_train_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Reference scores for the kernel run (presumably checked against ``self.out``)."""
        reference = _contrastive_ref(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Reference scores for the baseline run (presumably checked against ``self.out``)."""
        reference = _contrastive_ref(self)
        return reference
|
|
|
|
| |
| |
| |
|
|
class PaddedRerank(Benchmark):
    """Padded rerank at a typical inference shape: B=32, K=50, Ld=180.

    Forward only; query length is Lq=32 and every sequence is full length.
    """

    seed = _SEED

    def setup(self) -> None:
        """Generate the padded rerank inputs for this shape."""
        shape = {"B": 32, "C": 50, "Lq": 32, "Ld": 180}
        _make_padded(self, **shape)

    def benchmark_kernel(self) -> None:
        """Timed workload: padded kernel forward; scores kept in ``self.out``."""
        scores = _padded_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Timed workload: einsum baseline forward; scores kept in ``self.out``."""
        scores = _padded_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Reference scores for the kernel run (presumably checked against ``self.out``)."""
        reference = _padded_naive(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Reference scores for the baseline run (presumably checked against ``self.out``)."""
        reference = _padded_naive(self)
        return reference
|
|
|
|
class PaddedHeavyRerank(Benchmark):
    """Padded rerank at K=100 candidates, Ld=256 -- larger compute envelope.

    Forward only; B=32 queries of length Lq=32, every sequence full length.
    """

    seed = _SEED

    def setup(self) -> None:
        """Generate the padded rerank inputs for this shape."""
        shape = {"B": 32, "C": 100, "Lq": 32, "Ld": 256}
        _make_padded(self, **shape)

    def benchmark_kernel(self) -> None:
        """Timed workload: padded kernel forward; scores kept in ``self.out``."""
        scores = _padded_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Timed workload: einsum baseline forward; scores kept in ``self.out``."""
        scores = _padded_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Reference scores for the kernel run (presumably checked against ``self.out``)."""
        reference = _padded_naive(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Reference scores for the baseline run (presumably checked against ``self.out``)."""
        reference = _padded_naive(self)
        return reference
|
|
|
|
| |
| |
| |
|
|
class PackedRerank(Benchmark):
    """Rerank shape (B=32, K=50, Ld=180) expressed through the packed pair API.

    The kernel side runs on the flattened pair grid; the baseline and the
    reference both use the padded einsum on the same underlying tensors.
    """

    seed = _SEED

    def setup(self) -> None:
        """Generate the padded inputs plus their flattened pair-grid form."""
        shape = {"B": 32, "C": 50, "Lq": 32, "Ld": 180}
        _make_packed(self, **shape)

    def benchmark_kernel(self) -> None:
        """Timed workload: packed kernel forward; scores kept in ``self.out``."""
        scores = _packed_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Timed workload: padded einsum baseline; scores kept in ``self.out``."""
        scores = _padded_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Reference scores for the kernel run (presumably checked against ``self.out``)."""
        reference = _padded_naive(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Reference scores for the baseline run (presumably checked against ``self.out``)."""
        reference = _padded_naive(self)
        return reference
|
|
|
|
class PackedHeavyRerank(Benchmark):
    """Rerank shape (B=32, K=100, Ld=256) through the packed pair API.

    The kernel side runs on the flattened pair grid; the baseline and the
    reference both use the padded einsum on the same underlying tensors.
    """

    seed = _SEED

    def setup(self) -> None:
        """Generate the padded inputs plus their flattened pair-grid form."""
        shape = {"B": 32, "C": 100, "Lq": 32, "Ld": 256}
        _make_packed(self, **shape)

    def benchmark_kernel(self) -> None:
        """Timed workload: packed kernel forward; scores kept in ``self.out``."""
        scores = _packed_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Timed workload: padded einsum baseline; scores kept in ``self.out``."""
        scores = _padded_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Reference scores for the kernel run (presumably checked against ``self.out``)."""
        reference = _padded_naive(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Reference scores for the baseline run (presumably checked against ``self.out``)."""
        reference = _padded_naive(self)
        return reference
|
|