| """ |
| Integrate MixedPrecisionKVCache into Mistral/Llama generation. |
| Compares Naive (uint8) vs Triton (true 4-bit) implementations. |
| """ |
| import torch |
| import json |
| import os |
| import sys |
| import time |
| from datetime import datetime |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
| sys.path.append(os.path.expanduser("~/kv-hack")) |
| from kernel.quant_cache import MixedPrecisionKVCache |
| from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton |
|
|
| |
# Model to benchmark: first CLI argument, defaulting to Mistral-7B.
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"

# Local checkpoint directory for each supported model name.
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}

# Fail with a readable usage message instead of a bare KeyError when the
# requested model is not one of the known checkpoints.
if MODEL_NAME not in MODEL_PATHS:
    sys.exit(
        f"Unknown model {MODEL_NAME!r}; expected one of: "
        f"{', '.join(sorted(MODEL_PATHS))}"
    )

model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
# Per-model directory holding the bit allocation input and the result files.
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
|
|
| |
# Load the per-head bit widths; the JSON maps layer -> head -> bits, with
# both levels keyed by strings.
with open(f"{results_dir}/bit_allocation.json") as alloc_file:
    bit_alloc_raw = json.load(alloc_file)

# Re-key by integer layer index; each value is that layer's bit widths
# listed in head-index order.
bit_alloc = {
    int(layer_key): [head_bits[str(head)] for head in range(len(head_bits))]
    for layer_key, head_bits in bit_alloc_raw.items()
}
num_layers = len(bit_alloc)

# Flatten every head of every layer to get the overall average bit width.
all_bits = []
for layer_bits in bit_alloc.values():
    all_bits.extend(layer_bits)
avg_bits = sum(all_bits) / len(all_bits)

print(f"Model: {MODEL_NAME}")
print(f"Layers: {num_layers}")
print(f"Avg bits/head: {avg_bits:.2f}")
# Compression is reported relative to a 16-bit baseline.
print(f"Theoretical: {16/avg_bits:.2f}x compression")
|
|
| |
# Load the tokenizer and the half-precision model weights onto the GPU.
print(f"\nLoading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
load_kwargs = {"dtype": torch.float16, "device_map": "cuda"}
# eval() returns the module itself, so it can be chained onto the load call.
model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs).eval()
print(f"Model loaded. Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
|
|
|
|
| |
def run_quantized_generation(prompt: str, cache_class, max_new_tokens: int = 50):
    """
    Run greedy generation for one prompt and measure KV cache compression.

    Generation is timed with the model's regular cache. Afterwards a separate
    prefill pass over the prompt produces the KV tensors, and each layer's
    keys/values are stored into a fresh ``cache_class`` instance to measure
    the compressed size against the 2-bytes-per-element baseline.

    Args:
        prompt: Text to generate from.
        cache_class: MixedPrecisionKVCache or MixedPrecisionKVCacheTriton;
            constructed per layer with that layer's bit allocation.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        dict with keys ``text``, ``peak_memory_gb``, ``compressed_kb``,
        ``fp16_kb``, ``compression_ratio``, ``tokens_per_sec`` and
        ``time_sec``.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    prompt_len = inputs["input_ids"].shape[1]

    torch.cuda.reset_peak_memory_stats()
    # Drain pending GPU work so the timer only covers this generation.
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    # Wait for all queued kernels before stopping the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    peak_mem = torch.cuda.max_memory_allocated() / 1e9

    # generate() may stop at EOS before max_new_tokens, so count the tokens
    # actually produced (the output sequence includes the prompt tokens).
    new_tokens = out.shape[1] - prompt_len

    # Separate prefill pass to obtain the prompt's KV tensors for sizing.
    with torch.no_grad():
        prefill_out = model(**inputs, use_cache=True)
    kv = prefill_out.past_key_values

    compressed_bytes = 0
    fp16_bytes = 0
    for layer_idx in range(num_layers):
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values
        # Baseline size: 2 bytes per element for both keys and values.
        fp16_bytes += k.numel() * 2 + v.numel() * 2
        cache = cache_class(bit_alloc[layer_idx])
        cache.store(k, v)
        compressed_bytes += cache.memory_bytes()

    text = tokenizer.decode(out[0], skip_special_tokens=True)

    return {
        "text": text,
        "peak_memory_gb": round(peak_mem, 3),
        "compressed_kb": round(compressed_bytes / 1024, 1),
        "fp16_kb": round(fp16_bytes / 1024, 1),
        "compression_ratio": round(fp16_bytes / compressed_bytes, 2),
        "tokens_per_sec": round(new_tokens / elapsed, 1),
        "time_sec": round(elapsed, 2),
    }
|
|
|
|
| |
# Prompts benchmarked with both cache implementations.
prompts = [
    "The history of artificial intelligence began",
    "Explain how transformers work in deep learning:",
    "Write a Python function to sort a list:",
]

# Run metadata plus one list of per-prompt metrics per implementation.
all_results = {
    "model": MODEL_NAME,
    "timestamp": datetime.now().isoformat(),
    "avg_bits": avg_bits,
    "theoretical_compression": round(16 / avg_bits, 2),
    "naive": [],
    "triton": [],
}

# Metrics copied from each run into the saved results, in output order.
saved_metrics = (
    "compression_ratio",
    "peak_memory_gb",
    "tokens_per_sec",
    "compressed_kb",
    "fp16_kb",
)

print("\n" + "="*60)
print("NAIVE vs TRITON COMPARISON")
print("="*60)

for prompt in prompts:
    print(f"\nPrompt: {prompt[:55]}...")

    # Same prompt through both cache implementations, naive first.
    r_naive = run_quantized_generation(prompt, MixedPrecisionKVCache)
    r_triton = run_quantized_generation(prompt, MixedPrecisionKVCacheTriton)

    # Side-by-side table for this prompt.
    print(f"{'Metric':<22} {'Naive':>12} {'Triton':>12}")
    print(f"{'-'*48}")
    print(f"{'Peak memory (GB)':<22} {r_naive['peak_memory_gb']:>12.2f} {r_triton['peak_memory_gb']:>12.2f}")
    print(f"{'FP16 KV (KB)':<22} {r_naive['fp16_kb']:>12.0f} {r_triton['fp16_kb']:>12.0f}")
    print(f"{'Compressed KV (KB)':<22} {r_naive['compressed_kb']:>12.1f} {r_triton['compressed_kb']:>12.1f}")
    print(f"{'Compression ratio':<22} {r_naive['compression_ratio']:>11.2f}x {r_triton['compression_ratio']:>11.2f}x")
    print(f"{'Tokens/sec':<22} {r_naive['tokens_per_sec']:>12.1f} {r_triton['tokens_per_sec']:>12.1f}")
    # Show up to 120 characters of the continuation that follows the prompt.
    print(f"\nOutput: {r_triton['text'][len(prompt):len(prompt)+120]}")

    # Record the prompt and the selected metrics for each implementation.
    for impl_name, run in (("naive", r_naive), ("triton", r_triton)):
        entry = {"prompt": prompt}
        for metric in saved_metrics:
            entry[metric] = run[metric]
        all_results[impl_name].append(entry)
|
|
| |
def _mean_over_prompts(impl_name, metric):
    """Average one recorded metric across all prompts for an implementation."""
    total = sum(entry[metric] for entry in all_results[impl_name])
    return total / len(prompts)


print("\n" + "="*60)
print("SUMMARY")
print("="*60)

avg_naive_compression = _mean_over_prompts("naive", "compression_ratio")
avg_triton_compression = _mean_over_prompts("triton", "compression_ratio")
avg_naive_speed = _mean_over_prompts("naive", "tokens_per_sec")
avg_triton_speed = _mean_over_prompts("triton", "tokens_per_sec")

print(f"{'Metric':<28} {'Naive':>10} {'Triton':>10}")
print(f"{'-'*52}")
print(f"{'Avg compression ratio':<28} {avg_naive_compression:>9.2f}x {avg_triton_compression:>9.2f}x")
print(f"{'Avg tokens/sec':<28} {avg_naive_speed:>10.1f} {avg_triton_speed:>10.1f}")
# Improvement is the ratio of the two average compression ratios.
print(f"{'Triton memory improvement':<28} {'':>10} {avg_triton_compression/avg_naive_compression:>9.2f}x")

# Rounded copies of the averages are stored alongside the per-prompt data.
summary = {}
summary["avg_naive_compression"] = round(avg_naive_compression, 2)
summary["avg_triton_compression"] = round(avg_triton_compression, 2)
summary["avg_naive_speed"] = round(avg_naive_speed, 1)
summary["avg_triton_speed"] = round(avg_triton_speed, 1)
summary["triton_memory_improvement"] = round(avg_triton_compression / avg_naive_compression, 2)
all_results["summary"] = summary
|
|
| |
# Write the per-prompt metrics and summary into the model's results directory.
out_path = f"{results_dir}/integrate_results.json"
with open(out_path, "w", encoding="utf-8") as f:
    json.dump(all_results, f, indent=2)

# The check-mark emoji had been mis-decoded into "β" plus a stray line break,
# which split this string literal across two lines.
print(f"\n✅ Results saved to {out_path}")