# Source: maxsim/benchmarks/benchmark.py
# Uploaded by erikkaum using `kernel-builder` (revision 5fdc1d1, verified; 10.5 kB).
"""Benchmarks for the MaxSim kernel against naive PyTorch baselines.

This is the file the ``kernels`` CLI discovers and runs::

    kernels benchmark erikkaum/maxsim   # against the published kernel
    just bench-local                    # against the local ./build

The workloads mirror ``scripts/cuda_bench_matrix.py`` (the README-number
generator) so the CLI tables and the README matrix describe the same shapes:

* Contrastive training -- the headline ColBERT fine-tuning step (forward +
  backward), at the in-batch shape an actual PyLate batch hits.
* Padded inference -- the second-stage exact rerank (forward only).
* Packed inference -- the same rerank shape through the ragged pair API,
  to surface the packing/layout overhead.

Naive baselines are identical to the matrix's: a ``torch.einsum`` that
materialises the full similarity tensor before ``max`` -- what the kernel is
meant to beat. fp16 inputs throughout; the matrix covers bf16 as well.
"""
from __future__ import annotations
import torch
from kernels.benchmark import Benchmark
# Element dtype passed to ``torch.randn`` for every generated query/document
# tensor in this file (half precision).
_DTYPE = torch.float16
# Size of the trailing ``dim`` axis of every generated query/document tensor.
_DIM = 128
# Value each Benchmark subclass below assigns to its ``seed`` attribute; the
# workload builders feed ``self.seed`` to a ``torch.Generator``.
_SEED = 1234
# ---------------------------------------------------------------------------
# Naive PyTorch baselines (identical to scripts/cuda_bench_matrix.py).
# ---------------------------------------------------------------------------
def _naive_contrastive(q, docs, cu, d_lens):
    """All-pairs MaxSim over packed docs: materialises [Nq, Nb, Lq, Ld].

    Args:
        q: query tensor; its shape is unpacked as ``(Nq, Lq, dim)``.
        docs: packed document tokens, indexed along dim 0 by the offsets in
            ``cu`` and holding ``dim`` features per token.
        cu: 1-D tensor of token offsets; document ``i`` occupies rows
            ``cu[i]:cu[i + 1]`` of ``docs``.
        d_lens: Python sequence with the token count of each document.

    Returns:
        A ``[Nq, Nb]`` float32 tensor: for every (query, document) pair, the
        sum over query tokens of the maximum similarity against that
        document's tokens.

    Raises:
        ValueError: if ``d_lens`` is empty (``max`` of an empty sequence).

    Documents are scattered into a zero-padded ``[Nb, Ld_max, dim]`` buffer.
    When the lengths are ragged, the padded slots are masked to ``-inf``
    before the ``max``; otherwise a padded slot (similarity exactly 0) would
    win whenever all of a document's real similarities are negative. A
    zero-length document therefore scores ``-inf``, as in ``_naive_padded``.
    """
    _Nq, _Lq, dim = q.shape
    Nb = len(d_lens)
    Ld_max = max(d_lens)

    # Unpack the ragged documents into one dense, zero-padded batch.
    docs_padded = docs.new_zeros((Nb, Ld_max, dim))
    offs = cu.to(torch.int64).cpu().tolist()
    for i, ld_i in enumerate(d_lens):
        docs_padded[i, :ld_i] = docs[offs[i] : offs[i + 1]]

    sim = torch.einsum("qid,njd->qnij", q.float(), docs_padded.float())

    # Only pay for the mask when there is padding to hide: the uniform-length
    # case runs exactly the einsum + max + sum it always did.
    if min(d_lens) < Ld_max:
        lens = torch.tensor(d_lens, device=sim.device)
        is_pad = torch.arange(Ld_max, device=sim.device)[None, :] >= lens[:, None]
        sim = sim.masked_fill(is_pad[None, :, None, :], float("-inf"))

    return sim.max(dim=-1).values.sum(dim=-1)
def _naive_padded(q, d, qlen, dlen):
    """Per-query padded MaxSim: materialises [B, C, Lq, Ld].

    Args:
        q: queries; dim 1 is the query-token axis.
        d: candidate documents; its shape is unpacked as four axes, the third
            being the document-token axis.
        qlen: per-query valid token counts, broadcast as ``qlen[:, None]``.
        dlen: per-candidate valid token counts, broadcast as
            ``dlen[:, :, None]``.

    Returns:
        A ``[B, C]`` float32 tensor of summed per-query-token maxima.
        Document tokens at or beyond ``dlen`` are excluded from the max, and
        query tokens at or beyond ``qlen`` contribute 0 to the sum.
    """
    _, _, doc_tokens, _ = d.shape
    query_tokens = q.shape[1]
    device = q.device

    scores = torch.einsum("bid,bcjd->bcij", q.float(), d.float())

    # True where a position is padding (index >= valid length).
    query_pad = torch.arange(query_tokens, device=device)[None, :] >= qlen[:, None]
    doc_pad = torch.arange(doc_tokens, device=device)[None, None, :] >= dlen[:, :, None]

    # Padded document tokens must never win the max.
    scores = scores.masked_fill(doc_pad[:, :, None, :], float("-inf"))
    best = scores.max(dim=-1).values
    # Padded query tokens must not contribute to the sum.
    best = best.masked_fill(query_pad[:, None, :], 0.0)
    return best.sum(dim=-1)
# ---------------------------------------------------------------------------
# Workload construction (shared; each Benchmark passes its shape explicitly).
# ---------------------------------------------------------------------------
def _make_contrastive(self: Benchmark, Nq, Nb, Lq, Ld) -> None:
    """Attach a packed contrastive workload to ``self``.

    Sets ``self.queries`` (``[Nq, Lq, _DIM]``), ``self.documents``
    (``[Nb * Ld, _DIM]``, i.e. ``Nb`` documents of ``Ld`` tokens stored back
    to back), ``self.document_offsets`` (``Nb + 1`` int32 offsets in steps of
    ``Ld``) and ``self.d_lens`` (a list of ``Nb`` copies of ``Ld``).

    Random data is drawn from a generator seeded with ``self.seed`` and then
    moved to ``self.device``.
    """
    rng = torch.Generator().manual_seed(self.seed)

    # Queries are drawn before documents; both consume the same generator.
    query_data = torch.randn(Nq, Lq, _DIM, generator=rng, dtype=_DTYPE)
    self.queries = query_data.to(self.device)

    doc_data = torch.randn(Nb * Ld, _DIM, generator=rng, dtype=_DTYPE)
    self.documents = doc_data.to(self.device)

    # Every document has the same length, so the offsets are an even stride.
    total_tokens_exclusive = (Nb + 1) * Ld
    self.document_offsets = torch.arange(
        0,
        total_tokens_exclusive,
        Ld,
        dtype=torch.int32,
        device=self.device,
    )
    self.d_lens = [Ld] * Nb
def _make_padded(self: Benchmark, B, C, Lq, Ld) -> None:
    """Attach a dense padded rerank workload to ``self``.

    Sets ``self.queries`` (``[B, Lq, _DIM]``), ``self.documents``
    (``[B, C, Ld, _DIM]``), ``self.query_lengths`` (``[B]`` int32, all
    ``Lq``) and ``self.doc_lengths`` (``[B, C]`` int32, all ``Ld``), so
    every row is marked as fully valid.

    Random data is drawn from a generator seeded with ``self.seed`` and then
    moved to ``self.device``.
    """
    rng = torch.Generator().manual_seed(self.seed)

    # Queries are drawn before documents; both consume the same generator.
    query_data = torch.randn(B, Lq, _DIM, generator=rng, dtype=_DTYPE)
    self.queries = query_data.to(self.device)

    doc_data = torch.randn(B, C, Ld, _DIM, generator=rng, dtype=_DTYPE)
    self.documents = doc_data.to(self.device)

    length_kwargs = {"dtype": torch.int32, "device": self.device}
    self.query_lengths = torch.full((B,), Lq, **length_kwargs)
    self.doc_lengths = torch.full((B, C), Ld, **length_kwargs)
def _make_packed(self: Benchmark, B, C, Lq, Ld) -> None:
    """Padded tensors plus a flattened CSR pair grid expressing the same work.

    First builds the dense workload via ``_make_padded`` and then adds the
    packed view of it on ``self``: ``q_flat`` / ``d_flat`` (token-major 2-D
    reshapes of the padded tensors), ``q_offsets`` / ``d_offsets`` (int32
    offsets in steps of ``Lq`` / ``Ld``), ``pair_query_ids`` and
    ``pair_document_ids`` (one entry per (query, candidate) pair), and the
    bookkeeping scalars ``max_q_len``, ``batch`` and ``candidates``.
    """
    _make_padded(self, B, C, Lq, Ld)

    padded_q = self.queries
    padded_d = self.documents
    device = padded_q.device
    n_pairs = B * C

    self.max_q_len = Lq
    self.batch, self.candidates = B, C

    # Collapse every leading axis so each row is a single token.
    self.q_flat = padded_q.reshape(B * Lq, _DIM).contiguous()
    self.d_flat = padded_d.reshape(n_pairs * Ld, _DIM).contiguous()

    # All sequences are full length, so the offsets are an even stride.
    self.q_offsets = torch.arange(
        0, (B + 1) * Lq, Lq, dtype=torch.int32, device=device
    )
    self.d_offsets = torch.arange(
        0, (n_pairs + 1) * Ld, Ld, dtype=torch.int32, device=device
    )

    # Pair p scores flattened document p against query p // C.
    pair_index = torch.arange(n_pairs, dtype=torch.int32, device=device)
    self.pair_query_ids = pair_index // C
    self.pair_document_ids = pair_index
# ---------------------------------------------------------------------------
# Kernel / naive runners.
# ---------------------------------------------------------------------------
def _contrastive_train_kernel(self: Benchmark) -> torch.Tensor:
    """Run one forward + backward step through the kernel's training entry point.

    Calls ``self.kernel.score_contrastive_train`` on fresh gradient-tracking
    copies of ``self.queries`` / ``self.documents``, back-propagates the sum
    of the scores, and returns the scores detached from the graph.
    """
    # Fresh leaf copies each call: the stored workload tensors stay untouched
    # and never accumulate gradients.
    query_leaf = self.queries.detach().clone()
    query_leaf.requires_grad_(True)
    doc_leaf = self.documents.detach().clone()
    doc_leaf.requires_grad_(True)

    result = self.kernel.score_contrastive_train(
        query_leaf, doc_leaf, self.document_offsets
    )
    loss = result.sum()
    loss.backward()
    return result.detach()
def _contrastive_train_naive(self: Benchmark) -> torch.Tensor:
    """Run one forward + backward step through the naive einsum baseline.

    Same protocol as the kernel runner: fresh gradient-tracking copies of the
    inputs, scores from ``_naive_contrastive``, backward through the summed
    scores, and the scores returned detached.
    """
    # Fresh leaf copies each call so gradients never accumulate on the
    # stored workload tensors.
    query_leaf = self.queries.detach().clone()
    query_leaf.requires_grad_(True)
    doc_leaf = self.documents.detach().clone()
    doc_leaf.requires_grad_(True)

    result = _naive_contrastive(
        query_leaf, doc_leaf, self.document_offsets, self.d_lens
    )
    loss = result.sum()
    loss.backward()
    return result.detach()
def _contrastive_ref(self: Benchmark) -> torch.Tensor:
    """Return forward-only naive contrastive scores for the stored workload.

    Unlike the training runners, this passes ``self.queries`` and
    ``self.documents`` straight through, without cloning or enabling
    gradients.
    """
    reference = _naive_contrastive(
        self.queries,
        self.documents,
        self.document_offsets,
        self.d_lens,
    )
    return reference
def _padded_kernel(self: Benchmark) -> torch.Tensor:
    """Score the stored padded workload with ``kernel.score_candidates_padded``.

    Passes queries, documents, query lengths and document lengths, in that
    order, and returns whatever the kernel returns.
    """
    score = self.kernel.score_candidates_padded
    return score(
        self.queries,
        self.documents,
        self.query_lengths,
        self.doc_lengths,
    )
def _padded_naive(self: Benchmark) -> torch.Tensor:
    """Score the stored padded workload with the naive ``_naive_padded`` baseline."""
    reference = _naive_padded(
        self.queries,
        self.documents,
        self.query_lengths,
        self.doc_lengths,
    )
    return reference
def _packed_kernel(self: Benchmark) -> torch.Tensor:
    """Score the stored packed workload with ``kernel.score_pairs_packed``.

    The kernel result is viewed as ``[self.batch, self.candidates]`` before
    being returned, giving it the same layout as the padded runners' output.
    """
    flat_scores = self.kernel.score_pairs_packed(
        self.q_flat,
        self.q_offsets,
        self.d_flat,
        self.d_offsets,
        self.pair_query_ids,
        self.pair_document_ids,
        max_q_len=self.max_q_len,
    )
    return flat_scores.view(self.batch, self.candidates)
# ---------------------------------------------------------------------------
# Contrastive training (forward + backward; the headline workload).
# ---------------------------------------------------------------------------
class ContrastiveLateOn(Benchmark):
    """In-batch contrastive training: Nq=Nb=32, Lq=32, Ld=80, dim=128.

    Times one forward + backward step for the kernel and for the naive
    einsum baseline on the same packed workload.
    """

    seed = _SEED

    def setup(self) -> None:
        """Build the packed workload on ``self``."""
        _make_contrastive(
            self,
            Nq=32,
            Nb=32,
            Lq=32,
            Ld=80,
        )

    def benchmark_kernel(self) -> None:
        """Kernel training step; the scores are kept in ``self.out``."""
        scores = _contrastive_train_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Naive training step; the scores are kept in ``self.out``."""
        scores = _contrastive_train_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Forward-only naive scores, used as the reference for the kernel run."""
        reference = _contrastive_ref(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Forward-only naive scores, used as the reference for the naive run."""
        reference = _contrastive_ref(self)
        return reference
class ContrastiveLongDocs(Benchmark):
    """Same in-batch shape but long docs (Ld=512) -- stresses retained state.

    Nq=Nb=32 and Lq=32; only the document length differs from the Ld=80
    contrastive workload.
    """

    seed = _SEED

    def setup(self) -> None:
        """Build the packed long-document workload on ``self``."""
        _make_contrastive(
            self,
            Nq=32,
            Nb=32,
            Lq=32,
            Ld=512,
        )

    def benchmark_kernel(self) -> None:
        """Kernel training step; the scores are kept in ``self.out``."""
        scores = _contrastive_train_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Naive training step; the scores are kept in ``self.out``."""
        scores = _contrastive_train_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Forward-only naive scores, used as the reference for the kernel run."""
        reference = _contrastive_ref(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Forward-only naive scores, used as the reference for the naive run."""
        reference = _contrastive_ref(self)
        return reference
class ContrastiveBigBatch(Benchmark):
    """Doubled in-batch batch size: Nq=Nb=64, Lq=32, Ld=128, dim=128.

    Times one forward + backward step for the kernel and for the naive
    einsum baseline on the same packed workload.
    """

    seed = _SEED

    def setup(self) -> None:
        """Build the packed large-batch workload on ``self``."""
        _make_contrastive(
            self,
            Nq=64,
            Nb=64,
            Lq=32,
            Ld=128,
        )

    def benchmark_kernel(self) -> None:
        """Kernel training step; the scores are kept in ``self.out``."""
        scores = _contrastive_train_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Naive training step; the scores are kept in ``self.out``."""
        scores = _contrastive_train_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Forward-only naive scores, used as the reference for the kernel run."""
        reference = _contrastive_ref(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Forward-only naive scores, used as the reference for the naive run."""
        reference = _contrastive_ref(self)
        return reference
# ---------------------------------------------------------------------------
# Padded inference (second-stage rerank; forward only).
# ---------------------------------------------------------------------------
class PaddedRerank(Benchmark):
    """Padded rerank at a typical inference shape: B=32, K=50, Ld=180.

    The K=50 candidates per query are passed to the workload builder as
    ``C``; queries are Lq=32 tokens. Forward only.
    """

    seed = _SEED

    def setup(self) -> None:
        """Build the dense padded workload on ``self``."""
        _make_padded(
            self,
            B=32,
            C=50,
            Lq=32,
            Ld=180,
        )

    def benchmark_kernel(self) -> None:
        """Kernel forward pass; the scores are kept in ``self.out``."""
        scores = _padded_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Naive forward pass; the scores are kept in ``self.out``."""
        scores = _padded_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Naive padded scores, used as the reference for the kernel run."""
        reference = _padded_naive(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Naive padded scores, used as the reference for the naive run."""
        reference = _padded_naive(self)
        return reference
class PaddedHeavyRerank(Benchmark):
    """Padded rerank at K=100 candidates, Ld=256 -- larger compute envelope.

    B=32 queries of Lq=32 tokens; the K=100 candidates per query are passed
    to the workload builder as ``C``. Forward only.
    """

    seed = _SEED

    def setup(self) -> None:
        """Build the dense padded workload on ``self``."""
        _make_padded(
            self,
            B=32,
            C=100,
            Lq=32,
            Ld=256,
        )

    def benchmark_kernel(self) -> None:
        """Kernel forward pass; the scores are kept in ``self.out``."""
        scores = _padded_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Naive forward pass; the scores are kept in ``self.out``."""
        scores = _padded_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Naive padded scores, used as the reference for the kernel run."""
        reference = _padded_naive(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Naive padded scores, used as the reference for the naive run."""
        reference = _padded_naive(self)
        return reference
# ---------------------------------------------------------------------------
# Packed inference (same rerank shape via the ragged pair API; forward only).
# ---------------------------------------------------------------------------
class PackedRerank(Benchmark):
    """Rerank shape (B=32, K=50, Ld=180) expressed through the packed pair API.

    The kernel side runs on the packed tensors; the naive side and both
    references run on the padded tensors that ``_make_packed`` also stores.
    """

    seed = _SEED

    def setup(self) -> None:
        """Build the padded tensors and their packed view on ``self``."""
        _make_packed(
            self,
            B=32,
            C=50,
            Lq=32,
            Ld=180,
        )

    def benchmark_kernel(self) -> None:
        """Packed kernel forward pass; the scores are kept in ``self.out``."""
        scores = _packed_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Naive padded forward pass; the scores are kept in ``self.out``."""
        scores = _padded_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Naive padded scores, used as the reference for the kernel run."""
        reference = _padded_naive(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Naive padded scores, used as the reference for the naive run."""
        reference = _padded_naive(self)
        return reference
class PackedHeavyRerank(Benchmark):
    """Rerank shape (B=32, K=100, Ld=256) through the packed pair API.

    The kernel side runs on the packed tensors; the naive side and both
    references run on the padded tensors that ``_make_packed`` also stores.
    """

    seed = _SEED

    def setup(self) -> None:
        """Build the padded tensors and their packed view on ``self``."""
        _make_packed(
            self,
            B=32,
            C=100,
            Lq=32,
            Ld=256,
        )

    def benchmark_kernel(self) -> None:
        """Packed kernel forward pass; the scores are kept in ``self.out``."""
        scores = _packed_kernel(self)
        self.out = scores

    def benchmark_naive(self) -> None:
        """Naive padded forward pass; the scores are kept in ``self.out``."""
        scores = _padded_naive(self)
        self.out = scores

    def verify_kernel(self) -> torch.Tensor:
        """Naive padded scores, used as the reference for the kernel run."""
        reference = _padded_naive(self)
        return reference

    def verify_naive(self) -> torch.Tensor:
        """Naive padded scores, used as the reference for the naive run."""
        reference = _padded_naive(self)
        return reference