harshithsaiv committed on
Commit
2555c0e
·
1 Parent(s): 4b2bdf2

feat:Integrating the kernel to the model

Browse files
Files changed (1) hide show
  1. integrate.py +121 -0
integrate.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Integrate MixedPrecisionKVCache into Mistral/Llama generation.
3
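+ Runs regular generation for speed/memory, then compresses the prefill KV cache separately to measure the compression ratio (no forward-pass hooks are installed yet).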
4
+ """
5
+ import torch
6
+ import json
7
+ import os
8
+ import sys
9
+ import time
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
+
12
+ sys.path.append(os.path.expanduser("~/kv-hack"))
13
+ from kernel.quant_cache import MixedPrecisionKVCache
14
+
15
# ── config ──────────────────────────────────────────
# Model key comes from the first CLI argument and defaults to Mistral-7B.
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
if MODEL_NAME not in MODEL_PATHS:
    # Exit with a usable message instead of a bare KeyError traceback.
    raise SystemExit(
        f"Unknown model '{MODEL_NAME}'. Choose one of: {', '.join(MODEL_PATHS)}"
    )
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")

# load bit allocation
with open(f"{results_dir}/bit_allocation.json", encoding="utf-8") as f:
    bit_alloc_raw = json.load(f)

# JSON object keys are always strings: convert layer keys to ints and turn
# each per-layer {"0": bits, "1": bits, ...} mapping into a list ordered by
# head index.
bit_alloc = {
    int(layer_key): [head_bits[str(h)] for h in range(len(head_bits))]
    for layer_key, head_bits in bit_alloc_raw.items()
}
num_layers = len(bit_alloc)
print(f"Loaded bit allocation: {num_layers} layers")

# avg bits
all_bits = [bits for layer_bits in bit_alloc.values() for bits in layer_bits]
if not all_bits:
    # An empty allocation would otherwise fail below with ZeroDivisionError.
    raise SystemExit(f"Bit allocation in {results_dir} is empty")
avg_bits = sum(all_bits) / len(all_bits)
print(f"Average bits per head: {avg_bits:.2f} (vs 16 FP16)")
print(f"Theoretical compression: {16/avg_bits:.2f}x")

# ── load model ──────────────────────────────────────
print(f"\nLoading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Model loaded. Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
51
+
52
+ # ── run quantized inference ──────────────────────────
53
def run_quantized_generation(prompt: str, max_new_tokens: int = 100):
    """Generate from ``prompt`` and report speed, memory and KV compression.

    Generation runs through ``model.generate`` with the model's regular KV
    cache; that run supplies the timing and peak-memory figures. The
    compression figures come from a separate prefill pass over the prompt:
    each layer's keys/values are stored into a ``MixedPrecisionKVCache`` built
    from that layer's entry in ``bit_alloc``, and its ``memory_bytes()`` is
    compared against the same tensors counted at 2 bytes per element.

    Args:
        prompt: Text fed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        dict with keys ``text``, ``peak_memory_gb``, ``compressed_kb``,
        ``fp16_kb``, ``compression_ratio``, ``tokens_per_sec`` and
        ``time_sec``.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    prompt_len = inputs["input_ids"].shape[1]

    torch.cuda.reset_peak_memory_stats()
    # CUDA kernels launch asynchronously; synchronize so the timer brackets
    # the actual GPU work rather than just the kernel launches.
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    with torch.no_grad():
        # normal generation — measure memory and speed
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    peak_mem = torch.cuda.max_memory_allocated() / 1e9

    # Generation may stop early at EOS, so count the tokens actually produced
    # instead of assuming max_new_tokens were generated.
    new_tokens = out.shape[1] - prompt_len

    # separately measure KV cache compression ratio
    with torch.no_grad():
        prefill_out = model(**inputs, use_cache=True)
    kv = prefill_out.past_key_values

    compressed_bytes = 0
    fp16_bytes = 0
    for layer_idx in range(num_layers):
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values
        fp16_bytes += k.numel() * 2 + v.numel() * 2

        cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache.store(k, v)
        compressed_bytes += cache.memory_bytes()

    text = tokenizer.decode(out[0], skip_special_tokens=True)

    # Guard the two ratios so a degenerate measurement reports a value
    # instead of raising ZeroDivisionError.
    if compressed_bytes:
        compression_ratio = fp16_bytes / compressed_bytes
    else:
        compression_ratio = float("inf")
    tokens_per_sec = new_tokens / elapsed if elapsed > 0 else 0.0

    return {
        "text": text,
        "peak_memory_gb": round(peak_mem, 3),
        "compressed_kb": round(compressed_bytes / 1024, 1),
        "fp16_kb": round(fp16_bytes / 1024, 1),
        "compression_ratio": round(compression_ratio, 2),
        "tokens_per_sec": round(tokens_per_sec, 1),
        "time_sec": round(elapsed, 2),
    }
99
+
100
+
101
# ── test it ─────────────────────────────────────────
# Smoke test: run each prompt through run_quantized_generation and print the
# measured memory, KV-cache size, compression ratio and throughput.
prompts = [
    "The history of artificial intelligence began",
    "Explain how transformers work in deep learning:",
    "Write a Python function to sort a list:",
]

print("\n" + "="*60)
print("QUANTIZED INFERENCE TEST")
print("="*60)

for prompt in prompts:
    print(f"\nPrompt: {prompt[:50]}...")
    result = run_quantized_generation(prompt, max_new_tokens=50)
    print(f"Peak memory: {result['peak_memory_gb']:.2f} GB")
    print(f"KV cache: {result['fp16_kb']:.0f} KB → {result['compressed_kb']:.0f} KB")
    print(f"Compression: {result['compression_ratio']:.2f}x")
    print(f"Speed: {result['tokens_per_sec']:.1f} tokens/sec")
    # The decoded text normally begins with the prompt; strip it only when it
    # really does, so a tokenizer round-trip difference cannot cut the output
    # at an arbitrary offset.
    text = result["text"]
    continuation = text[len(prompt):] if text.startswith(prompt) else text
    print(f"Output: {continuation[:150]}")

print("\n✅ Quantized inference working!")