""" Integrate MixedPrecisionKVCache into Mistral/Llama generation. Compares Naive (uint8) vs Triton (true 4-bit) implementations. """ import torch import json import os import sys import time from datetime import datetime from transformers import AutoTokenizer, AutoModelForCausalLM sys.path.append(os.path.expanduser("~/kv-hack")) from kernel.quant_cache import MixedPrecisionKVCache from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton # ── config ────────────────────────────────────────── MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b" MODEL_PATHS = { "mistral-7b": "~/kv-hack/mistral-model", "llama-3-8b": "~/kv-hack/llama-model", } model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME]) results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}") # load bit allocation with open(f"{results_dir}/bit_allocation.json") as f: bit_alloc_raw = json.load(f) bit_alloc = { int(l): [bit_alloc_raw[l][str(h)] for h in range(len(bit_alloc_raw[l]))] for l in bit_alloc_raw } num_layers = len(bit_alloc) all_bits = [b for l in bit_alloc.values() for b in l] avg_bits = sum(all_bits) / len(all_bits) print(f"Model: {MODEL_NAME}") print(f"Layers: {num_layers}") print(f"Avg bits/head: {avg_bits:.2f}") print(f"Theoretical: {16/avg_bits:.2f}x compression") # ── load model ────────────────────────────────────── print(f"\nLoading {MODEL_NAME}...") tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForCausalLM.from_pretrained( model_path, dtype=torch.float16, device_map="cuda" ) model.eval() print(f"Model loaded. Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB") # ── core generation function ───────────────────────── def run_quantized_generation(prompt: str, cache_class, max_new_tokens: int = 50): """ Run generation and measure KV cache compression. 

# ── core generation function ────────────────────────
def run_quantized_generation(prompt: str, cache_class, max_new_tokens: int = 50):
    """
    Generate with the model's standard FP16 KV cache, then separately measure
    how much the prefill KV tensors shrink under the given quantized cache.

    cache_class: MixedPrecisionKVCache or MixedPrecisionKVCacheTriton
    """
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    torch.cuda.reset_peak_memory_stats()
    t0 = time.time()
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )
    elapsed = time.time() - t0
    peak_mem = torch.cuda.max_memory_allocated() / 1e9
    # count tokens actually generated (greedy decoding may stop at EOS
    # before max_new_tokens)
    gen_tokens = out.shape[1] - inputs["input_ids"].shape[1]

    # measure KV cache compression separately on a prefill pass:
    # quantize each layer's K/V with its per-head bit allocation and compare
    # the packed size against the FP16 footprint
    with torch.no_grad():
        prefill_out = model(**inputs, use_cache=True)
    kv = prefill_out.past_key_values
    compressed_bytes = 0
    fp16_bytes = 0
    for layer_idx in range(num_layers):
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values
        fp16_bytes += k.numel() * 2 + v.numel() * 2  # 2 bytes per fp16 element
        cache = cache_class(bit_alloc[layer_idx])
        cache.store(k, v)
        compressed_bytes += cache.memory_bytes()

    text = tokenizer.decode(out[0], skip_special_tokens=True)
    return {
        "text": text,
        "peak_memory_gb": round(peak_mem, 3),
        "compressed_kb": round(compressed_bytes / 1024, 1),
        "fp16_kb": round(fp16_bytes / 1024, 1),
        "compression_ratio": round(fp16_bytes / compressed_bytes, 2),
        "tokens_per_sec": round(gen_tokens / elapsed, 1),
        "time_sec": round(elapsed, 2),
    }

# ── run comparison ──────────────────────────────────
prompts = [
    "The history of artificial intelligence began",
    "Explain how transformers work in deep learning:",
    "Write a Python function to sort a list:",
]
all_results = {
    "model": MODEL_NAME,
    "timestamp": datetime.now().isoformat(),
    "avg_bits": avg_bits,
    "theoretical_compression": round(16 / avg_bits, 2),
    "naive": [],
    "triton": [],
}

print("\n" + "=" * 60)
print("NAIVE vs TRITON COMPARISON")
print("=" * 60)

for prompt in prompts:
    print(f"\nPrompt: {prompt[:55]}...")
    r_naive = run_quantized_generation(prompt, MixedPrecisionKVCache)
    r_triton = run_quantized_generation(prompt, MixedPrecisionKVCacheTriton)

    print(f"{'Metric':<22} {'Naive':>12} {'Triton':>12}")
    print("-" * 48)
    print(f"{'Peak memory (GB)':<22} {r_naive['peak_memory_gb']:>12.2f} {r_triton['peak_memory_gb']:>12.2f}")
    print(f"{'FP16 KV (KB)':<22} {r_naive['fp16_kb']:>12.0f} {r_triton['fp16_kb']:>12.0f}")
    print(f"{'Compressed KV (KB)':<22} {r_naive['compressed_kb']:>12.1f} {r_triton['compressed_kb']:>12.1f}")
    print(f"{'Compression ratio':<22} {r_naive['compression_ratio']:>11.2f}x {r_triton['compression_ratio']:>11.2f}x")
    print(f"{'Tokens/sec':<22} {r_naive['tokens_per_sec']:>12.1f} {r_triton['tokens_per_sec']:>12.1f}")
    print(f"\nOutput: {r_triton['text'][len(prompt):len(prompt) + 120]}")

    all_results["naive"].append({
        "prompt": prompt,
        "compression_ratio": r_naive["compression_ratio"],
        "peak_memory_gb": r_naive["peak_memory_gb"],
        "tokens_per_sec": r_naive["tokens_per_sec"],
        "compressed_kb": r_naive["compressed_kb"],
        "fp16_kb": r_naive["fp16_kb"],
    })
    all_results["triton"].append({
        "prompt": prompt,
        "compression_ratio": r_triton["compression_ratio"],
        "peak_memory_gb": r_triton["peak_memory_gb"],
        "tokens_per_sec": r_triton["tokens_per_sec"],
        "compressed_kb": r_triton["compressed_kb"],
        "fp16_kb": r_triton["fp16_kb"],
    })

# ── summary ─────────────────────────────────────────
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)

avg_naive_compression = sum(r["compression_ratio"] for r in all_results["naive"]) / len(prompts)
avg_triton_compression = sum(r["compression_ratio"] for r in all_results["triton"]) / len(prompts)
avg_naive_speed = sum(r["tokens_per_sec"] for r in all_results["naive"]) / len(prompts)
avg_triton_speed = sum(r["tokens_per_sec"] for r in all_results["triton"]) / len(prompts)

print(f"{'Metric':<28} {'Naive':>10} {'Triton':>10}") print(f"{'-'*52}") print(f"{'Avg compression ratio':<28} {avg_naive_compression:>9.2f}x {avg_triton_compression:>9.2f}x") print(f"{'Avg tokens/sec':<28} {avg_naive_speed:>10.1f} {avg_triton_speed:>10.1f}") print(f"{'Triton memory improvement':<28} {'':>10} {avg_triton_compression/avg_naive_compression:>9.2f}x") all_results["summary"] = { "avg_naive_compression": round(avg_naive_compression, 2), "avg_triton_compression": round(avg_triton_compression, 2), "avg_naive_speed": round(avg_naive_speed, 1), "avg_triton_speed": round(avg_triton_speed, 1), "triton_memory_improvement": round(avg_triton_compression / avg_naive_compression, 2), } # ── save ───────────────────────────────────────────── out_path = f"{results_dir}/integrate_results.json" with open(out_path, "w") as f: json.dump(all_results, f, indent=2) print(f"\n✅ Results saved to {out_path}")