"""
Long-context benchmarks, from 512 up to 32K tokens.
The 16K and 32K points are where KV cache compression matters most.
"""
import torch
import json
import os
import sys
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache

MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
model_path  = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")

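# bit_allocation.json is assumed to map layer index -> {head index -> bit width},
# e.g. {"0": {"0": 8, "1": 4, ...}, ...}; JSON keys are strings, hence the casts below.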
with open(f"{results_dir}/bit_allocation.json") as f:
    raw = json.load(f)
bit_alloc = {
    int(l): [raw[l][str(h)] for h in range(len(raw[l]))]
    for l in raw
}
num_layers = len(bit_alloc)
avg_bits   = sum(b for l in bit_alloc.values() for b in l) / \
             sum(len(l) for l in bit_alloc.values())

print(f"Model: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model     = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")

def measure_context(context_len: int):
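    """Prefill once at `context_len` and report peak memory, KV-cache sizes, and prefill latency."""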
    print(f"\n  Context {context_len} tokens...")
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()

    # peak GPU memory for one FP16 prefill (resident model weights count toward the peak)
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
        kv  = out.past_key_values
    torch.cuda.synchronize()
    peak_mem = torch.cuda.max_memory_allocated() / 1e9

    # KV compression
    fp16_bytes       = 0
    uniform8_bytes   = 0
    compressed_bytes = 0

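    # Compare three per-layer KV sizes: FP16 baseline (2 bytes/element), uniform
    # INT8 (1 byte/element), and the mixed-precision cache packed at the per-head
    # bit widths loaded from bit_allocation.json.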
    for layer_idx in range(num_layers):
        # indexing works for both legacy tuple caches and DynamicCache objects
        k, v = kv[layer_idx]
        fp16_bytes     += k.numel() * 2 + v.numel() * 2
        uniform8_bytes += k.numel() + v.numel()
        cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache.store(k, v)
        compressed_bytes += cache.memory_bytes()

    # prefill latency, averaged over 3 runs (the memory pass above serves as warm-up)
    times = []
    for _ in range(3):
        torch.cuda.synchronize()
        t0 = time.time()
        with torch.no_grad():
            _ = model(input_ids, use_cache=True)
        torch.cuda.synchronize()
        times.append(time.time() - t0)
    prefill_ms = round(sum(times) / len(times) * 1000, 1)

    return {
        "context_len":          context_len,
        "peak_memory_gb":       round(peak_mem, 2),
        "fp16_mb":              round(fp16_bytes / 1e6, 2),
        "uniform8_mb":          round(uniform8_bytes / 1e6, 2),
        "mixed_precision_mb":   round(compressed_bytes / 1e6, 2),
        "compression_vs_fp16":  round(fp16_bytes / compressed_bytes, 2),
        "compression_vs_8bit":  round(uniform8_bytes / compressed_bytes, 2),
        "prefill_ms":           prefill_ms,
    }

print("\n" + "="*60)
print("LONG CONTEXT BENCHMARK")
print("="*60)

results = []
for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
    try:
        r = measure_context(ctx)
        results.append(r)
        print(f"  ctx={ctx:6d} | "
              f"mem={r['peak_memory_gb']:.2f}GB | "
              f"FP16={r['fp16_mb']:.0f}MB | "
              f"Ours={r['mixed_precision_mb']:.0f}MB | "
              f"{r['compression_vs_fp16']}x | "
              f"prefill={r['prefill_ms']}ms")
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        print(f"  ctx={ctx:6d} | OOM: FP16 ran out of memory ✓")
        # record an analytic FP16 estimate instead
        # (assumes 8 KV heads, head_dim 128, K+V at 2 bytes per element)
        results.append({
            "context_len": ctx,
            "peak_memory_gb": "OOM",
            "fp16_mb": ctx * num_layers * 8 * 128 * 4 / 1e6,
            "note": "FP16 OOM, compressed cache may still fit",
        })
        break

# save
out_path = f"{results_dir}/long_context_results.json"
with open(out_path, "w") as f:
    json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
print(f"\n✅ Saved to {out_path}")