"""Compare SpectralQuant codebook Triton kernels against Python fallbacks.

This is a focused kernel equivalence test. It initializes the real Gemma 4
sidecar, then checks local and global layers for:

- compress: packed cache bytes match exactly
- compress: stored fp16 norms match exactly
- dequant: bf16 dequantized output matches exactly

Run on the Lightning Docker image with the active vllm-spectral checkout on
PYTHONPATH, for example:

    cd /workspace/vllm-spectral
    PYTHONPATH=/workspace/vllm-spectral:/workspace/vllm-spectral/vllm/third_party \
        python3 /workspace/gemmacut/test_triton_codebook_match.py
"""

from __future__ import annotations

import argparse

import torch


def python_compress(
    spectral,
    data: torch.Tensor,
    norms: torch.Tensor,
    slots: torch.Tensor,
    pack_maps: tuple[torch.Tensor, ...],
    codebooks,
    kv: str,
    layer_idx: int,
    head_offset: int,
    cache: torch.Tensor,
    norm_buf: torch.Tensor,
    block_size: int,
) -> None:
    H = data.shape[1]
    D_cache = cache.shape[-1]
    packed_dim = spectral._PACKED_DIMS[layer_idx]
    hi_src, lo_src, is_sem, has_lo, valid = pack_maps

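    # Quantize all heads against this layer's codebooks; K and V carry
    # separate semantic/tail centroid tables.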
| if kv == "k": |
| indices = spectral._quantize_all_heads( |
| data, |
| codebooks.k_semantic_centroids, |
| codebooks.k_tail_centroids, |
| codebooks.k_d_eff_int, |
| ) |
| kv_idx = 0 |
| else: |
| indices = spectral._quantize_all_heads( |
| data, |
| codebooks.v_semantic_centroids, |
| codebooks.v_tail_centroids, |
| codebooks.v_d_eff_int, |
| ) |
| kv_idx = 1 |
|
|
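    # Pack the codebook indices into the uint8 cache layout; dims beyond
    # packed_dim stay zero up to the allocated cache width.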
    packed = torch.zeros(
        data.shape[0], H, D_cache, dtype=torch.uint8, device=data.device
    )
    packed[:, :, :packed_dim] = spectral._pack_all_heads(
        indices, hi_src, lo_src, is_sem, has_lo, valid
    )

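    # Scatter packed rows into their paged-cache blocks and record the fp16
    # norms; negative slots mark padding tokens and are skipped.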
    good = slots >= 0
    good_slots = slots[good]
    block_idx = good_slots // block_size
    block_off = good_slots % block_size
    cache[block_idx, block_off] = packed[good]
    norm_buf[good_slots, head_offset : head_offset + H, kv_idx] = norms[good].to(
        torch.float16
    )


def python_dequant(
    spectral,
    cache: torch.Tensor,
    norm_buf: torch.Tensor,
    unique_blocks: torch.Tensor,
    unpack_map: tuple[torch.Tensor, ...],
    codebooks,
    kv: str,
    layer_idx: int,
    head_offset: int,
    H: int,
    D: int,
    block_size: int,
) -> torch.Tensor:
    src, is_sem, is_high = unpack_map
    out = torch.zeros(
        unique_blocks.shape[0] * block_size,
        H,
        D,
        dtype=torch.bfloat16,
        device=cache.device,
    )

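    # unique_blocks pads unused program slots with -1; output rows for
    # those slots stay zero.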
    valid_prog = unique_blocks >= 0
    if not valid_prog.any():
        return out

    blocks = unique_blocks[valid_prog]
    raw = cache[blocks].reshape(-1, H, cache.shape[-1])
    indices = spectral._unpack_all_heads(
        raw[:, :, : spectral._PACKED_DIMS[layer_idx]], src, is_sem, is_high
    )

    if kv == "k":
        vals = spectral._dequantize_all_heads(
            indices,
            codebooks.k_semantic_centroids,
            codebooks.k_tail_centroids,
            codebooks.k_d_eff_int,
        )
        kv_idx = 0
    else:
        vals = spectral._dequantize_all_heads(
            indices,
            codebooks.v_semantic_centroids,
            codebooks.v_tail_centroids,
            codebooks.v_d_eff_int,
        )
        kv_idx = 1

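    # Rescale the dequantized vectors by the fp16 norms stored at compress
    # time for each (token, head).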
    offsets = torch.arange(block_size, device=cache.device)
    slot_indices = (blocks.unsqueeze(1) * block_size + offsets.unsqueeze(0)).reshape(-1)
    vals = vals * norm_buf[
        slot_indices, head_offset : head_offset + H, kv_idx
    ].float().unsqueeze(-1)

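    # Write the dequantized rows back in bf16, but only into the positions
    # that belong to valid program slots.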
    valid_rows = torch.nonzero(valid_prog, as_tuple=False).flatten()
    out_view = out.reshape(unique_blocks.shape[0], block_size, H, D)
    out_view[valid_rows] = vals.reshape(blocks.shape[0], block_size, H, D).to(
        torch.bfloat16
    )
    return out


def check_layer(
    spectral,
    cal,
    layer_idx: int,
    block_size: int,
    num_blocks: int,
    slots: torch.Tensor,
    unique_blocks: torch.Tensor,
    device: str,
) -> list[str]:
    failures: list[str] = []
    lc = cal.get_layer(layer_idx)
    if lc is None:
        return [f"layer {layer_idx} missing from calibration"]

    codebooks = spectral._LAYER_CODEBOOKS[layer_idx]
    H, D = lc.num_kv_heads, lc.head_dim
    D_cache = spectral._ALLOC_DIMS[layer_idx]
    head_offset = spectral._NORM_BUFFER_LAYER_OFFSETS[layer_idx]

    print(
        f"CHECK layer={layer_idx} type={lc.layer_type} H={H} D={D} "
        f"packed={spectral._PACKED_DIMS[layer_idx]} alloc={D_cache}",
        flush=True,
    )

    for kv in ("k", "v"):
        pack_maps = spectral._PACK_MAPS[(layer_idx, kv)]
        unpack_map = spectral._UNPACK_MAPS[(layer_idx, kv)]

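        # Synthetic inputs: small-magnitude activations and strictly
        # positive per-head norms.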
        data = (torch.randn(slots.shape[0], H, D, device=device) * 0.07).contiguous()
        norms = (torch.rand(slots.shape[0], H, device=device) * 3.0 + 0.25).contiguous()

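        # Separate destination buffers so the Python and Triton paths
        # cannot contaminate each other.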
        cache_py = torch.zeros(
            num_blocks, block_size, H, D_cache, dtype=torch.uint8, device=device
        )
        cache_tri = torch.zeros_like(cache_py)
        norm_py = torch.zeros_like(spectral._NORM_BUFFER)
        norm_tri = torch.zeros_like(spectral._NORM_BUFFER)

        python_compress(
            spectral,
            data,
            norms,
            slots,
            pack_maps,
            codebooks,
            kv,
            layer_idx,
            head_offset,
            cache_py,
            norm_py,
            block_size,
        )

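        # Point the module-global norm buffer at the Triton scratch tensor
        # so the kernel's norm writes land in norm_tri.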
        spectral._NORM_BUFFER = norm_tri
        spectral._triton_compress(
            data,
            norms,
            slots,
            pack_maps,
            codebooks,
            kv,
            cache_tri,
            layer_idx,
            head_offset,
        )
        torch.cuda.synchronize()

        if not torch.equal(cache_py, cache_tri):
            diff = cache_py != cache_tri
            idx = diff.nonzero(as_tuple=False)[0].tolist()
            failures.append(
                f"compress cache mismatch layer={layer_idx} kv={kv} "
                f"first_idx={idx} py={int(cache_py[tuple(idx)])} "
                f"tri={int(cache_tri[tuple(idx)])} count={int(diff.sum())}"
            )

        kv_idx = 0 if kv == "k" else 1
        norm_slice = (slice(None), slice(head_offset, head_offset + H), kv_idx)
        if not torch.equal(norm_py[norm_slice], norm_tri[norm_slice]):
            max_diff = (
                norm_py[norm_slice].float() - norm_tri[norm_slice].float()
            ).abs().max().item()
            failures.append(
                f"norm mismatch layer={layer_idx} kv={kv} max_abs={max_diff}"
            )

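        # Dequantize from the Python-produced cache and norms on both paths
        # so the dequant check is isolated from any compress mismatch.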
        out_py = python_dequant(
            spectral,
            cache_py,
            norm_py,
            unique_blocks,
            unpack_map,
            codebooks,
            kv,
            layer_idx,
            head_offset,
            H,
            D,
            block_size,
        )
        out_tri = torch.zeros_like(out_py)
        spectral._NORM_BUFFER = norm_py
        spectral._triton_dequant(
            cache_py,
            unique_blocks,
            unpack_map,
            codebooks,
            kv,
            head_offset,
            out_tri,
            block_size,
            H,
            D,
            max_blocks=unique_blocks.shape[0],
        )
        torch.cuda.synchronize()

        if not torch.equal(out_py, out_tri):
            abs_diff = (out_py.float() - out_tri.float()).abs()
            nz = (out_py != out_tri).nonzero(as_tuple=False)
            first = nz[0].tolist() if nz.numel() else None
            failures.append(
                f"dequant output mismatch layer={layer_idx} kv={kv} "
                f"max_abs={abs_diff.max().item()} first_idx={first} "
                f"count={int((out_py != out_tri).sum())}"
            )
        else:
            print(
                f"PASS layer={layer_idx} kv={kv}: compress bytes exact, "
                "norms exact, dequant bf16 exact",
                flush=True,
            )

    return failures


def main() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sidecar",
        default="/workspace/gemmacut/results_it/spectral_sidecar_chat_v2.pt",
    )
    parser.add_argument("--layers", default="0,5", help="comma-separated layer ids")
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()

    from vllm.v1.attention import spectral

    torch.manual_seed(1234)
    spectral.init_spectral(args.sidecar, spectral_quantize=True, device=args.device)
    spectral.init_norm_buffer(4096, device=args.device)
    cal = spectral.get_calibration()
    if cal is None:
        raise RuntimeError("spectral calibration did not load")

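    # Small paged-cache fixture; the -1 entries in slots and unique_blocks
    # exercise the padding/sentinel paths in both kernels.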
    block_size = 16
    num_blocks = 9
    slots = torch.tensor(
        [0, 1, 2, 15, 16, 17, -1, 31, 32, 45, 46, 47, 63, 64, -1, 80, 81, 95, 96],
        device=args.device,
        dtype=torch.long,
    )
    unique_blocks = torch.tensor(
        [0, 1, 2, 3, 4, 5, 6, -1], device=args.device, dtype=torch.long
    )

    failures: list[str] = []
    for layer_idx in [int(x) for x in args.layers.split(",") if x.strip()]:
        failures.extend(
            check_layer(
                spectral,
                cal,
                layer_idx,
                block_size,
                num_blocks,
                slots,
                unique_blocks,
                args.device,
            )
        )

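    # check_layer leaves _NORM_BUFFER pointing at per-test scratch tensors;
    # drop the reference before reporting.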
    spectral._NORM_BUFFER = None
    for failure in failures:
        print("FAIL", failure, flush=True)
    if failures:
        return 1

    print("ALL_MATCH", flush=True)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())