# Source: hydra/tests/test_csr.py
# Commit 7298fd0 (verified) by Frosty40: Publish Hydra kernel source packet
from __future__ import annotations
import pytest
# Skip every test in this module (instead of erroring at import time) when
# torch or triton cannot be imported in the current environment.
# NOTE(review): the tests below only request device="cpu"; presumably
# hydra.csr imports triton at module level, which is why it is required
# here as well -- confirm against hydra/csr.py.
pytest.importorskip("torch")
pytest.importorskip("triton")
def test_dense_causal_csr_cpu_contract():
    """Pin the CPU output of ``build_dense_causal_csr`` for a tiny problem.

    The builder is called with one batch entry, two heads, and a sequence of
    128 positions split into blocks of 32.  The test then checks:

    * the row-pointer tensor has shape ``(1, 2, 5)``;
    * the returned sequence lengths are ``[128]``;
    * for batch 0 / head 0, row ``q`` lists the column blocks ``0..q`` in
      ascending order, giving cumulative row pointers ``0, 1, 3, 6, 10``.
    """
    from hydra.csr import build_dense_causal_csr

    builder_kwargs = {
        "batch_size": 1,
        "num_heads": 2,
        "seq_len": 128,
        "block_size": 32,
        "device": "cpu",
    }
    offsets, columns, lengths = build_dense_causal_csr(**builder_kwargs)

    # Expected layout for batch 0 / head 0: rows [0], [0, 1], [0, 1, 2],
    # [0, 1, 2, 3] flattened, plus the matching cumulative offsets.
    expected_offsets = [0, 1, 3, 6, 10]
    expected_columns = [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]

    assert offsets.shape == (1, 2, 5)
    assert lengths.tolist() == [128]
    assert offsets[0, 0].tolist() == expected_offsets
    assert columns[0, 0].tolist() == expected_columns
def test_sliding_window_csr_keeps_diagonal_last():
    """Check the per-row ordering produced by ``build_sliding_window_csr``.

    The builder is called on CPU with ``window=64``, ``seq_len=128``,
    ``block_size=32``, one batch entry and one head.  For each of the first
    four query blocks of batch 0 / head 0 the test requires that:

    * the final column index stored for the row equals the query block's own
      index (the diagonal block comes last), and
    * every column index stored before it is strictly smaller than the query
      block index.
    """
    from hydra.csr import build_sliding_window_csr

    offsets_t, columns_t, _ = build_sliding_window_csr(
        window=64,
        seq_len=128,
        block_size=32,
        batch_size=1,
        num_heads=1,
        device="cpu",
    )
    offsets = offsets_t[0, 0].tolist()
    columns = columns_t[0, 0].tolist()

    num_query_blocks = 4
    for q_block in range(num_query_blocks):
        start = offsets[q_block]
        stop = offsets[q_block + 1]

        # The last entry of the row must be the diagonal block itself.
        assert columns[stop - 1] == q_block

        # Everything ahead of the diagonal must be an earlier block.
        for key_block in columns[start : stop - 1]:
            assert key_block < q_block