"""Measure the realized SpectralQuant KV-cache footprint.

This is not a throughput benchmark. It initializes the real vLLM engine and
prints the actual KV-cache tensor allocation from the live kv_cache_config,
then compares it to the equivalent fp8 full-head KV-cache layout for Gemma 4.
"""
|
|
| from __future__ import annotations |
|
|
| from dataclasses import is_dataclass |
| from typing import Any |
|
|
| from vllm import LLM |
|
|
|
|
| def _get_attr(obj: Any, name: str, default: Any = None) -> Any: |
| return getattr(obj, name, default) |
|
|
|
|
| def _bytes(num: float) -> str: |
| mib = num / (1024**2) |
| gib = num / (1024**3) |
| if gib >= 1: |
| return f"{gib:.3f} GiB" |
| return f"{mib:.3f} MiB" |
|
|
|
|
| def _tensor_bytes(tensors: dict[str, Any]) -> int: |
| total = 0 |
| for tensor in tensors.values(): |
| total += tensor.numel() * tensor.element_size() |
| return total |
|
|
|
|
def _iter_layer_specs(cfg: Any):
    """Yield ``(layer_name, spec)`` for every layer in every KV-cache group.

    When a group's spec carries a per-layer ``kv_cache_specs`` mapping, the
    per-layer entry is yielded; otherwise the group-wide spec is shared by
    all of the group's layers.
    """
    for group in cfg.kv_cache_groups:
        group_spec = group.kv_cache_spec
        per_layer = _get_attr(group_spec, "kv_cache_specs")
        for layer_name in group.layer_names:
            if per_layer is None:
                yield layer_name, group_spec
            else:
                yield layer_name, per_layer[layer_name]
|
|
|
|
| def _worker_from_llm(llm: LLM) -> Any: |
| executor = llm.llm_engine.model_executor |
| return executor.driver_worker.worker |
|
|
|
|
def _print_tensor_summaries(cfg: Any) -> None:
    """Print the first few KV-cache tensor allocations, eliding the rest."""
    for i, tensor in enumerate(cfg.kv_cache_tensors):
        if i >= 5:
            remaining = len(cfg.kv_cache_tensors) - i
            print(f"tensor[{i}..] {remaining} more tensors omitted")
            break
        print(
            f"tensor[{i}] size={tensor.size} shared_by={len(tensor.shared_by)} "
            f"first={tensor.shared_by[:2]}"
        )


def _print_group_specs(cfg: Any) -> None:
    """Print one summary line per KV-cache group spec."""
    for i, group in enumerate(cfg.kv_cache_groups):
        spec = group.kv_cache_spec
        fields = {
            "layers": len(group.layer_names),
            "spec": type(spec).__name__,
            "page_size": _get_attr(spec, "page_size_bytes"),
            "real_page_size": _get_attr(spec, "real_page_size_bytes"),
            "block_size": _get_attr(spec, "block_size"),
            "num_kv_heads": _get_attr(spec, "num_kv_heads"),
            "head_size": _get_attr(spec, "head_size"),
        }
        if is_dataclass(spec):
            # Dataclass specs get a full repr for field-level inspection.
            fields["repr"] = repr(spec)
        print(f"group[{i}] {fields}")


def _scratch_bytes(spectral: Any) -> int:
    """Sum the transient dequant/rotate scratch buffers (not persistent KV)."""
    buffers = (
        spectral._DEQUANT_KEY_BUF,
        spectral._DEQUANT_VAL_BUF,
        spectral._DEQUANT_REMAP,
        spectral._DEQUANT_ACTIVE_MASK,
        spectral._DEQUANT_BLOCK_LIST,
        spectral._ROTATE_BUF_K,
        spectral._ROTATE_BUF_V,
        spectral._ROTATE_NORMS_K,
        spectral._ROTATE_NORMS_V,
    )
    return sum(_tensor_bytes(buf) for buf in buffers)


def _baseline_per_block(cfg: Any, spectral: Any) -> tuple[int, int]:
    """Return ``(bytes_per_block, layer_count)`` for an fp8 full-head baseline.

    Uses the spectral calibration sidecar to recover each layer's original
    head geometry (fp8 = 1 byte per element, x2 for K and V planes). Layers
    without a calibration entry fall back to the spec's real page size. If no
    calibration is loaded at all (or it yields zero bytes), falls back to a
    hard-coded 60-layer geometry — presumably Gemma 4's split of 50 layers of
    16 heads x 256 dim plus 10 layers of 4 heads x 512 dim (see the module
    docstring); confirm against the model config if reused elsewhere.
    """
    baseline_per_block = 0
    baseline_layers = 0
    cal = spectral.get_calibration()
    if cal is not None:
        for layer_name, spec in _iter_layer_specs(cfg):
            layer_idx = spectral._extract_layer_index(layer_name)
            lc = cal.get_layer(layer_idx)
            if lc is None:
                baseline_per_block += _get_attr(spec, "real_page_size_bytes", 0)
            else:
                # fp8: 1 byte/element, times 2 for the K and V planes.
                baseline_per_block += (
                    2 * spec.block_size * lc.num_kv_heads * lc.head_dim
                )
            baseline_layers += 1
    if baseline_per_block == 0:
        block_sizes = [
            _get_attr(group.kv_cache_spec, "block_size")
            for group in cfg.kv_cache_groups
            if _get_attr(group.kv_cache_spec, "block_size") is not None
        ]
        base_block_size = min(block_sizes) if block_sizes else 16
        baseline_per_block = (
            50 * 2 * base_block_size * 16 * 256
            + 10 * 2 * base_block_size * 4 * 512
        )
        baseline_layers = 60
    return baseline_per_block, baseline_layers


def main() -> None:
    """Initialize the real engine and report the live KV-cache footprint.

    Prints the allocated KV-cache tensor bytes, the spectral norm/scratch
    buffers, and compression ratios against the equivalent fp8 and bf16
    full-head layouts. This is a measurement script, not a benchmark.
    """
    llm = LLM(
        model="Intel/gemma-4-31B-it-int4-AutoRound",
        spectral_calibration="/workspace/gemmacut/results_it/spectral_sidecar_chat_v2.pt",
        spectral_quantize=True,
        kv_cache_dtype="fp8_e4m3",
        max_model_len=512,
        max_num_batched_tokens=512,
        max_num_seqs=1,
        gpu_memory_utilization=0.8,
        # No ahead-of-time compile sizes: we only need allocation, not speed.
        compilation_config={"compile_sizes": []},
    )

    worker = _worker_from_llm(llm)
    cfg = worker.model_runner.kv_cache_config
    num_blocks = cfg.num_blocks

    kv_tensor_bytes = sum(t.size for t in cfg.kv_cache_tensors)
    # Fix: guard the division against a zero block count (the norm_per_block
    # computation below already guarded it; this one previously did not).
    kv_tensor_per_block = kv_tensor_bytes / num_blocks if num_blocks else 0

    print("MEASURE spectral_quantize=True kv_cache_dtype=fp8_e4m3")
    print(f"num_blocks={num_blocks}")
    print(f"kv_cache_tensors={len(cfg.kv_cache_tensors)}")
    print(
        f"kv_tensor_bytes={kv_tensor_bytes} ({_bytes(kv_tensor_bytes)}) "
        f"per_block={kv_tensor_per_block:.0f}"
    )

    _print_tensor_summaries(cfg)
    _print_group_specs(cfg)

    # Imported late on purpose: the spectral module's buffers are only
    # populated after the engine initialization above.
    from vllm.v1.attention import spectral

    norm = spectral._NORM_BUFFER
    norm_bytes = 0 if norm is None else norm.numel() * norm.element_size()
    norm_per_block = norm_bytes / num_blocks if num_blocks else 0
    print(
        f"norm_buffer_bytes={norm_bytes} ({_bytes(norm_bytes)}) "
        f"per_block={norm_per_block:.0f} shape={None if norm is None else tuple(norm.shape)}"
    )

    scratch_bytes = _scratch_bytes(spectral)
    print(
        f"scratch_dequant_rotate_bytes={scratch_bytes} ({_bytes(scratch_bytes)}) "
        "not counted as persistent KV cache"
    )

    baseline_per_block, baseline_layers = _baseline_per_block(cfg, spectral)
    baseline_total = baseline_per_block * num_blocks

    codebook_plus_norm = kv_tensor_bytes + norm_bytes
    print(
        f"baseline_fp8_full_head_bytes={baseline_total} ({_bytes(baseline_total)}) "
        f"per_block={baseline_per_block} layers={baseline_layers}"
    )
    print(
        f"codebook_kv_plus_norm_bytes={codebook_plus_norm} "
        f"({_bytes(codebook_plus_norm)}) "
        f"per_block={kv_tensor_per_block + norm_per_block:.0f}"
    )
    print(
        "compression_vs_fp8_full_head "
        f"kv_tensor_only={baseline_total / kv_tensor_bytes:.4f}x "
        f"savings={100 * (1 - kv_tensor_bytes / baseline_total):.2f}%"
    )
    print(
        "compression_vs_fp8_full_head "
        f"kv_plus_norm={baseline_total / codebook_plus_norm:.4f}x "
        f"savings={100 * (1 - codebook_plus_norm / baseline_total):.2f}%"
    )
    print(
        "compression_vs_bf16_full_head "
        f"kv_plus_norm={(2 * baseline_total) / codebook_plus_norm:.4f}x "
        f"savings={100 * (1 - codebook_plus_norm / (2 * baseline_total)):.2f}%"
    )
|
|
|
|
# Script entry point: run the footprint measurement when executed directly.
if __name__ == "__main__":
    main()
|
|