# Source: gemmacut-spectral / scripts/test_triton_codebook_match.py
# Uploaded by satya007 ("Add files using upload-large-folder tool"),
# commit af7d321 (verified).
#!/usr/bin/env python3
"""Compare SpectralQuant codebook Triton kernels against Python fallbacks.

This is a focused kernel equivalence test. It initializes the real Gemma 4
sidecar, then checks local and global layers for:

- compress: packed cache bytes match exactly
- compress: stored fp16 norms match exactly
- dequant: bf16 dequantized output matches exactly

Run on the Lightning Docker image with the active vllm-spectral checkout on
PYTHONPATH, for example:

    cd /workspace/vllm-spectral
    PYTHONPATH=/workspace/vllm-spectral:/workspace/vllm-spectral/vllm/third_party \
    python3 /workspace/gemmacut/test_triton_codebook_match.py
"""
from __future__ import annotations
import argparse
import torch
def python_compress(
    spectral,
    data: torch.Tensor,
    norms: torch.Tensor,
    slots: torch.Tensor,
    pack_maps: tuple[torch.Tensor, ...],
    codebooks,
    kv: str,
    layer_idx: int,
    head_offset: int,
    cache: torch.Tensor,
    norm_buf: torch.Tensor,
    block_size: int,
) -> None:
    """Reference (non-Triton) compress path: quantize, pack, scatter.

    Quantizes ``data`` with the K or V codebooks, packs the resulting indices
    into bytes, and writes the packed rows and fp16 norms for every
    non-negative slot into ``cache`` and ``norm_buf`` in place.

    Args:
        spectral: Object providing ``_PACKED_DIMS``, ``_quantize_all_heads``
            and ``_pack_all_heads``.
        data: Values to quantize; dim 0 has one row per entry of ``slots`` and
            dim 1 is the head count.
        norms: Per-row, per-head norms; stored as float16.
        slots: Flat slot ids, one per row of ``data``. Negative entries are
            skipped.
        pack_maps: The five pack-map tensors
            ``(hi_src, lo_src, is_sem, has_lo, valid)``.
        codebooks: Object exposing ``{k,v}_semantic_centroids``,
            ``{k,v}_tail_centroids`` and ``{k,v}_d_eff_int``.
        kv: ``"k"`` or ``"v"``; selects the codebooks and the norm column
            (0 for K, 1 for V).
        layer_idx: Key into ``spectral._PACKED_DIMS``.
        head_offset: First head column of this layer inside ``norm_buf``.
        cache: uint8 cache indexed as ``[block, offset_in_block, head, byte]``.
        norm_buf: Norm buffer indexed as ``[slot, head, kv_idx]``.
        block_size: Number of slots per cache block.

    Raises:
        ValueError: If ``kv`` is neither ``"k"`` nor ``"v"``.
    """
    # Reject unknown selectors up front instead of silently treating them as V.
    if kv not in ("k", "v"):
        raise ValueError(f"kv must be 'k' or 'v', got {kv!r}")
    kv_idx = 0 if kv == "k" else 1

    num_heads = data.shape[1]
    alloc_dim = cache.shape[-1]
    packed_dim = spectral._PACKED_DIMS[layer_idx]
    hi_src, lo_src, is_sem, has_lo, valid = pack_maps

    indices = spectral._quantize_all_heads(
        data,
        getattr(codebooks, f"{kv}_semantic_centroids"),
        getattr(codebooks, f"{kv}_tail_centroids"),
        getattr(codebooks, f"{kv}_d_eff_int"),
    )

    # The cache row may be wider than the packed payload; the remaining bytes
    # stay zero.
    packed = torch.zeros(
        data.shape[0], num_heads, alloc_dim, dtype=torch.uint8, device=data.device
    )
    packed[:, :, :packed_dim] = spectral._pack_all_heads(
        indices, hi_src, lo_src, is_sem, has_lo, valid
    )

    # Only rows with a non-negative slot id are written.
    good = slots >= 0
    good_slots = slots[good]
    block_idx = good_slots // block_size
    block_off = good_slots % block_size
    cache[block_idx, block_off] = packed[good]
    norm_buf[good_slots, head_offset : head_offset + num_heads, kv_idx] = norms[
        good
    ].to(torch.float16)
def python_dequant(
    spectral,
    cache: torch.Tensor,
    norm_buf: torch.Tensor,
    unique_blocks: torch.Tensor,
    unpack_map: tuple[torch.Tensor, ...],
    codebooks,
    kv: str,
    layer_idx: int,
    head_offset: int,
    H: int,
    D: int,
    block_size: int,
) -> torch.Tensor:
    """Reference (non-Triton) dequant path: gather, unpack, rescale.

    Reads the cache blocks listed in ``unique_blocks``, unpacks and
    dequantizes them with the K or V codebooks, multiplies by the stored
    norms, and returns the result as bfloat16.

    Args:
        spectral: Object providing ``_PACKED_DIMS``, ``_unpack_all_heads`` and
            ``_dequantize_all_heads``.
        cache: Packed cache indexed by block id along dim 0.
        norm_buf: Norm buffer indexed as ``[slot, head, kv_idx]``.
        unique_blocks: Block ids to dequantize; negative entries are skipped
            and their output rows stay zero.
        unpack_map: The three unpack-map tensors ``(src, is_sem, is_high)``.
        codebooks: Object exposing ``{k,v}_semantic_centroids``,
            ``{k,v}_tail_centroids`` and ``{k,v}_d_eff_int``.
        kv: ``"k"`` or ``"v"``; selects the codebooks and the norm column
            (0 for K, 1 for V).
        layer_idx: Key into ``spectral._PACKED_DIMS``.
        head_offset: First head column of this layer inside ``norm_buf``.
        H: Number of heads.
        D: Head dimension of the dequantized output.
        block_size: Number of slots per cache block.

    Returns:
        A bfloat16 tensor of shape ``(len(unique_blocks) * block_size, H, D)``
        on ``cache.device``.

    Raises:
        ValueError: If ``kv`` is neither ``"k"`` nor ``"v"``.
    """
    # Reject unknown selectors up front instead of silently treating them as V.
    if kv not in ("k", "v"):
        raise ValueError(f"kv must be 'k' or 'v', got {kv!r}")
    kv_idx = 0 if kv == "k" else 1

    src, is_sem, is_high = unpack_map
    num_programs = unique_blocks.shape[0]
    out = torch.zeros(
        num_programs * block_size,
        H,
        D,
        dtype=torch.bfloat16,
        device=cache.device,
    )

    valid_prog = unique_blocks >= 0
    if not valid_prog.any():
        return out

    blocks = unique_blocks[valid_prog]
    raw = cache[blocks].reshape(-1, H, cache.shape[-1])
    # Only the first packed_dim bytes of each cache row carry payload.
    indices = spectral._unpack_all_heads(
        raw[:, :, : spectral._PACKED_DIMS[layer_idx]], src, is_sem, is_high
    )
    vals = spectral._dequantize_all_heads(
        indices,
        getattr(codebooks, f"{kv}_semantic_centroids"),
        getattr(codebooks, f"{kv}_tail_centroids"),
        getattr(codebooks, f"{kv}_d_eff_int"),
    )

    # Flat slot id of every position in every selected block.
    offsets = torch.arange(block_size, device=cache.device)
    slot_indices = (
        blocks.unsqueeze(1) * block_size + offsets.unsqueeze(0)
    ).reshape(-1)
    scale = norm_buf[slot_indices, head_offset : head_offset + H, kv_idx]
    vals = vals * scale.float().unsqueeze(-1)

    # Write each valid block back at its original position in unique_blocks.
    valid_rows = torch.nonzero(valid_prog, as_tuple=False).flatten()
    out_view = out.reshape(num_programs, block_size, H, D)
    out_view[valid_rows] = vals.reshape(blocks.shape[0], block_size, H, D).to(
        torch.bfloat16
    )
    return out
def check_layer(
    spectral,
    cal,
    layer_idx: int,
    block_size: int,
    num_blocks: int,
    slots: torch.Tensor,
    unique_blocks: torch.Tensor,
    device: str,
) -> list[str]:
    """Compare the Triton compress/dequant kernels with the Python reference.

    For each of K and V on one layer, random inputs are compressed by both
    ``python_compress`` and ``spectral._triton_compress`` into separate caches
    and norm buffers, which must match exactly. The Python-written cache is
    then dequantized by both ``python_dequant`` and
    ``spectral._triton_dequant``, and the bf16 outputs must match exactly.

    ``spectral._NORM_BUFFER`` is temporarily replaced while the Triton entry
    points run and is restored to its entry value before returning, also when
    an exception propagates.

    Args:
        spectral: The spectral module (codebooks, maps, Triton entry points).
        cal: Calibration object; ``cal.get_layer(layer_idx)`` supplies
            ``num_kv_heads``, ``head_dim`` and ``layer_type``.
        layer_idx: Layer to check.
        block_size: Number of slots per cache block.
        num_blocks: Number of blocks in the scratch caches.
        slots: Flat slot ids to compress into; negative entries are skipped.
        unique_blocks: Block ids to dequantize; negative entries are skipped.
        device: Device on which the test tensors are created.

    Returns:
        One message per mismatch; an empty list when everything matched.
    """
    failures: list[str] = []
    lc = cal.get_layer(layer_idx)
    if lc is None:
        return [f"layer {layer_idx} missing from calibration"]

    codebooks = spectral._LAYER_CODEBOOKS[layer_idx]
    H, D = lc.num_kv_heads, lc.head_dim
    D_cache = spectral._ALLOC_DIMS[layer_idx]
    head_offset = spectral._NORM_BUFFER_LAYER_OFFSETS[layer_idx]
    print(
        f"CHECK layer={layer_idx} type={lc.layer_type} H={H} D={D} "
        f"packed={spectral._PACKED_DIMS[layer_idx]} alloc={D_cache}",
        flush=True,
    )

    # The module-level norm buffer is swapped below (presumably the Triton
    # entry points read it, since they take no norm-buffer argument); keep the
    # original so it can be put back afterwards.
    saved_norm_buffer = spectral._NORM_BUFFER
    try:
        for kv in ("k", "v"):
            failures_before = len(failures)
            pack_maps = spectral._PACK_MAPS[(layer_idx, kv)]
            unpack_map = spectral._UNPACK_MAPS[(layer_idx, kv)]

            # Values are already in the rotated, normalized basis here. This
            # isolates codebook quantization, packing, norm storage, and
            # dequantization.
            data = (
                torch.randn(slots.shape[0], H, D, device=device) * 0.07
            ).contiguous()
            norms = (
                torch.rand(slots.shape[0], H, device=device) * 3.0 + 0.25
            ).contiguous()

            cache_py = torch.zeros(
                num_blocks, block_size, H, D_cache, dtype=torch.uint8, device=device
            )
            cache_tri = torch.zeros_like(cache_py)
            norm_py = torch.zeros_like(saved_norm_buffer)
            norm_tri = torch.zeros_like(saved_norm_buffer)

            # --- compress: reference vs. Triton ---
            python_compress(
                spectral,
                data,
                norms,
                slots,
                pack_maps,
                codebooks,
                kv,
                layer_idx,
                head_offset,
                cache_py,
                norm_py,
                block_size,
            )
            spectral._NORM_BUFFER = norm_tri
            spectral._triton_compress(
                data,
                norms,
                slots,
                pack_maps,
                codebooks,
                kv,
                cache_tri,
                layer_idx,
                head_offset,
            )
            torch.cuda.synchronize()

            if not torch.equal(cache_py, cache_tri):
                diff = cache_py != cache_tri
                idx = diff.nonzero(as_tuple=False)[0].tolist()
                failures.append(
                    f"compress cache mismatch layer={layer_idx} kv={kv} "
                    f"first_idx={idx} py={int(cache_py[tuple(idx)])} "
                    f"tri={int(cache_tri[tuple(idx)])} count={int(diff.sum())}"
                )

            # Only this layer's head columns of this K/V plane are compared.
            kv_idx = 0 if kv == "k" else 1
            norm_slice = (slice(None), slice(head_offset, head_offset + H), kv_idx)
            if not torch.equal(norm_py[norm_slice], norm_tri[norm_slice]):
                max_diff = (
                    (norm_py[norm_slice].float() - norm_tri[norm_slice].float())
                    .abs()
                    .max()
                    .item()
                )
                failures.append(
                    f"norm mismatch layer={layer_idx} kv={kv} max_abs={max_diff}"
                )

            # --- dequant: both paths read the Python-written cache and norms,
            # so a compress mismatch cannot leak into this comparison ---
            out_py = python_dequant(
                spectral,
                cache_py,
                norm_py,
                unique_blocks,
                unpack_map,
                codebooks,
                kv,
                layer_idx,
                head_offset,
                H,
                D,
                block_size,
            )
            out_tri = torch.zeros_like(out_py)
            spectral._NORM_BUFFER = norm_py
            spectral._triton_dequant(
                cache_py,
                unique_blocks,
                unpack_map,
                codebooks,
                kv,
                head_offset,
                out_tri,
                block_size,
                H,
                D,
                max_blocks=unique_blocks.shape[0],
            )
            torch.cuda.synchronize()

            if not torch.equal(out_py, out_tri):
                mismatch = out_py != out_tri
                abs_diff = (out_py.float() - out_tri.float()).abs()
                nz = mismatch.nonzero(as_tuple=False)
                first = nz[0].tolist() if nz.numel() else None
                failures.append(
                    f"dequant output mismatch layer={layer_idx} kv={kv} "
                    f"max_abs={abs_diff.max().item()} first_idx={first} "
                    f"count={int(mismatch.sum())}"
                )

            # The PASS line claims all three checks were exact, so print it
            # only when none of them recorded a failure for this kv.
            if len(failures) == failures_before:
                print(
                    f"PASS layer={layer_idx} kv={kv}: compress bytes exact, "
                    "norms exact, dequant bf16 exact",
                    flush=True,
                )
    finally:
        # Never leave the module pointing at a throwaway test buffer.
        spectral._NORM_BUFFER = saved_norm_buffer

    return failures
def main() -> int:
    """Run the kernel equivalence checks and return a process exit code.

    Parses the command line, initializes the spectral sidecar and norm buffer,
    runs ``check_layer`` for every requested layer, and prints either one
    ``FAIL`` line per mismatch or ``ALL_MATCH``.

    Returns:
        0 when every checked layer matched, 1 when any mismatch was reported.

    Raises:
        SystemExit: With status 2 (via ``argparse``) when ``--layers`` is
            empty or contains a non-integer id.
        RuntimeError: If the spectral calibration did not load.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sidecar",
        default="/workspace/gemmacut/results_it/spectral_sidecar_chat_v2.pt",
    )
    parser.add_argument("--layers", default="0,5", help="comma-separated layer ids")
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()

    # Validate the layer list before the expensive sidecar initialization.
    # An empty list must be rejected: it would run zero checks and still
    # report ALL_MATCH.
    try:
        layer_ids = [int(tok) for tok in args.layers.split(",") if tok.strip()]
    except ValueError:
        parser.error(
            f"--layers must be comma-separated integers, got {args.layers!r}"
        )
    if not layer_ids:
        parser.error("--layers must name at least one layer id")

    # Imported here so that argument errors and --help do not need vllm.
    from vllm.v1.attention import spectral

    torch.manual_seed(1234)
    spectral.init_spectral(args.sidecar, spectral_quantize=True, device=args.device)
    spectral.init_norm_buffer(4096, device=args.device)
    cal = spectral.get_calibration()
    if cal is None:
        raise RuntimeError("spectral calibration did not load")

    block_size = 16
    num_blocks = 9
    # Slot ids covering block starts, block ends and skipped (-1) entries.
    slots = torch.tensor(
        [0, 1, 2, 15, 16, 17, -1, 31, 32, 45, 46, 47, 63, 64, -1, 80, 81, 95, 96],
        device=args.device,
        dtype=torch.long,
    )
    # Blocks to dequantize; the trailing -1 exercises the skipped-block path.
    unique_blocks = torch.tensor(
        [0, 1, 2, 3, 4, 5, 6, -1], device=args.device, dtype=torch.long
    )

    failures: list[str] = []
    try:
        for layer_idx in layer_ids:
            failures.extend(
                check_layer(
                    spectral,
                    cal,
                    layer_idx,
                    block_size,
                    num_blocks,
                    slots,
                    unique_blocks,
                    args.device,
                )
            )
    finally:
        # Drop the module-level norm buffer even if a check raised.
        spectral._NORM_BUFFER = None

    for failure in failures:
        print("FAIL", failure, flush=True)
    if failures:
        return 1
    print("ALL_MATCH", flush=True)
    return 0
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)