#!/usr/bin/env python3 """Compare SpectralQuant codebook Triton kernels against Python fallbacks. This is a focused kernel equivalence test. It initializes the real Gemma 4 sidecar, then checks local and global layers for: - compress: packed cache bytes match exactly - compress: stored fp16 norms match exactly - dequant: bf16 dequantized output matches exactly Run on the Lightning Docker image with the active vllm-spectral checkout on PYTHONPATH, for example: cd /workspace/vllm-spectral PYTHONPATH=/workspace/vllm-spectral:/workspace/vllm-spectral/vllm/third_party \ python3 /workspace/gemmacut/test_triton_codebook_match.py """ from __future__ import annotations import argparse import torch def python_compress( spectral, data: torch.Tensor, norms: torch.Tensor, slots: torch.Tensor, pack_maps: tuple[torch.Tensor, ...], codebooks, kv: str, layer_idx: int, head_offset: int, cache: torch.Tensor, norm_buf: torch.Tensor, block_size: int, ) -> None: H = data.shape[1] D_cache = cache.shape[-1] packed_dim = spectral._PACKED_DIMS[layer_idx] hi_src, lo_src, is_sem, has_lo, valid = pack_maps if kv == "k": indices = spectral._quantize_all_heads( data, codebooks.k_semantic_centroids, codebooks.k_tail_centroids, codebooks.k_d_eff_int, ) kv_idx = 0 else: indices = spectral._quantize_all_heads( data, codebooks.v_semantic_centroids, codebooks.v_tail_centroids, codebooks.v_d_eff_int, ) kv_idx = 1 packed = torch.zeros( data.shape[0], H, D_cache, dtype=torch.uint8, device=data.device ) packed[:, :, :packed_dim] = spectral._pack_all_heads( indices, hi_src, lo_src, is_sem, has_lo, valid ) good = slots >= 0 good_slots = slots[good] block_idx = good_slots // block_size block_off = good_slots % block_size cache[block_idx, block_off] = packed[good] norm_buf[good_slots, head_offset : head_offset + H, kv_idx] = norms[good].to( torch.float16 ) def python_dequant( spectral, cache: torch.Tensor, norm_buf: torch.Tensor, unique_blocks: torch.Tensor, unpack_map: tuple[torch.Tensor, ...], codebooks, kv: str, layer_idx: int, head_offset: int, H: int, D: int, block_size: int, ) -> torch.Tensor: src, is_sem, is_high = unpack_map out = torch.zeros( unique_blocks.shape[0] * block_size, H, D, dtype=torch.bfloat16, device=cache.device, ) valid_prog = unique_blocks >= 0 if not valid_prog.any(): return out blocks = unique_blocks[valid_prog] raw = cache[blocks].reshape(-1, H, cache.shape[-1]) indices = spectral._unpack_all_heads( raw[:, :, : spectral._PACKED_DIMS[layer_idx]], src, is_sem, is_high ) if kv == "k": vals = spectral._dequantize_all_heads( indices, codebooks.k_semantic_centroids, codebooks.k_tail_centroids, codebooks.k_d_eff_int, ) kv_idx = 0 else: vals = spectral._dequantize_all_heads( indices, codebooks.v_semantic_centroids, codebooks.v_tail_centroids, codebooks.v_d_eff_int, ) kv_idx = 1 offsets = torch.arange(block_size, device=cache.device) slot_indices = (blocks.unsqueeze(1) * block_size + offsets.unsqueeze(0)).reshape(-1) vals = vals * norm_buf[ slot_indices, head_offset : head_offset + H, kv_idx ].float().unsqueeze(-1) valid_rows = torch.nonzero(valid_prog, as_tuple=False).flatten() out_view = out.reshape(unique_blocks.shape[0], block_size, H, D) out_view[valid_rows] = vals.reshape(blocks.shape[0], block_size, H, D).to( torch.bfloat16 ) return out def check_layer( spectral, cal, layer_idx: int, block_size: int, num_blocks: int, slots: torch.Tensor, unique_blocks: torch.Tensor, device: str, ) -> list[str]: failures: list[str] = [] lc = cal.get_layer(layer_idx) if lc is None: return [f"layer {layer_idx} missing from calibration"] codebooks = spectral._LAYER_CODEBOOKS[layer_idx] H, D = lc.num_kv_heads, lc.head_dim D_cache = spectral._ALLOC_DIMS[layer_idx] head_offset = spectral._NORM_BUFFER_LAYER_OFFSETS[layer_idx] print( f"CHECK layer={layer_idx} type={lc.layer_type} H={H} D={D} " f"packed={spectral._PACKED_DIMS[layer_idx]} alloc={D_cache}", flush=True, ) for kv in ("k", "v"): pack_maps = spectral._PACK_MAPS[(layer_idx, kv)] unpack_map = spectral._UNPACK_MAPS[(layer_idx, kv)] # Values are already in the rotated, normalized basis here. This isolates # codebook quantization, packing, norm storage, and dequantization. data = (torch.randn(slots.shape[0], H, D, device=device) * 0.07).contiguous() norms = (torch.rand(slots.shape[0], H, device=device) * 3.0 + 0.25).contiguous() cache_py = torch.zeros( num_blocks, block_size, H, D_cache, dtype=torch.uint8, device=device ) cache_tri = torch.zeros_like(cache_py) norm_py = torch.zeros_like(spectral._NORM_BUFFER) norm_tri = torch.zeros_like(spectral._NORM_BUFFER) python_compress( spectral, data, norms, slots, pack_maps, codebooks, kv, layer_idx, head_offset, cache_py, norm_py, block_size, ) spectral._NORM_BUFFER = norm_tri spectral._triton_compress( data, norms, slots, pack_maps, codebooks, kv, cache_tri, layer_idx, head_offset ) torch.cuda.synchronize() if not torch.equal(cache_py, cache_tri): diff = cache_py != cache_tri idx = diff.nonzero(as_tuple=False)[0].tolist() failures.append( f"compress cache mismatch layer={layer_idx} kv={kv} " f"first_idx={idx} py={int(cache_py[tuple(idx)])} " f"tri={int(cache_tri[tuple(idx)])} count={int(diff.sum())}" ) kv_idx = 0 if kv == "k" else 1 norm_slice = (slice(None), slice(head_offset, head_offset + H), kv_idx) if not torch.equal(norm_py[norm_slice], norm_tri[norm_slice]): max_diff = ( norm_py[norm_slice].float() - norm_tri[norm_slice].float() ).abs().max().item() failures.append( f"norm mismatch layer={layer_idx} kv={kv} max_abs={max_diff}" ) out_py = python_dequant( spectral, cache_py, norm_py, unique_blocks, unpack_map, codebooks, kv, layer_idx, head_offset, H, D, block_size, ) out_tri = torch.zeros_like(out_py) spectral._NORM_BUFFER = norm_py spectral._triton_dequant( cache_py, unique_blocks, unpack_map, codebooks, kv, head_offset, out_tri, block_size, H, D, max_blocks=unique_blocks.shape[0], ) torch.cuda.synchronize() if not torch.equal(out_py, out_tri): abs_diff = (out_py.float() - out_tri.float()).abs() nz = (out_py != out_tri).nonzero(as_tuple=False) first = nz[0].tolist() if nz.numel() else None failures.append( f"dequant output mismatch layer={layer_idx} kv={kv} " f"max_abs={abs_diff.max().item()} first_idx={first} " f"count={int((out_py != out_tri).sum())}" ) else: print( f"PASS layer={layer_idx} kv={kv}: compress bytes exact, " "norms exact, dequant bf16 exact", flush=True, ) return failures def main() -> int: parser = argparse.ArgumentParser() parser.add_argument( "--sidecar", default="/workspace/gemmacut/results_it/spectral_sidecar_chat_v2.pt", ) parser.add_argument("--layers", default="0,5", help="comma-separated layer ids") parser.add_argument("--device", default="cuda") args = parser.parse_args() from vllm.v1.attention import spectral torch.manual_seed(1234) spectral.init_spectral(args.sidecar, spectral_quantize=True, device=args.device) spectral.init_norm_buffer(4096, device=args.device) cal = spectral.get_calibration() if cal is None: raise RuntimeError("spectral calibration did not load") block_size = 16 num_blocks = 9 slots = torch.tensor( [0, 1, 2, 15, 16, 17, -1, 31, 32, 45, 46, 47, 63, 64, -1, 80, 81, 95, 96], device=args.device, dtype=torch.long, ) unique_blocks = torch.tensor( [0, 1, 2, 3, 4, 5, 6, -1], device=args.device, dtype=torch.long ) failures: list[str] = [] for layer_idx in [int(x) for x in args.layers.split(",") if x.strip()]: failures.extend( check_layer( spectral, cal, layer_idx, block_size, num_blocks, slots, unique_blocks, args.device, ) ) spectral._NORM_BUFFER = None for failure in failures: print("FAIL", failure, flush=True) if failures: return 1 print("ALL_MATCH", flush=True) return 0 if __name__ == "__main__": raise SystemExit(main())