"""
Long context benchmarks at 16K and 32K.
This is where KV cache compression matters most.
4 methods: FP16, Uniform 8-bit, Naive Per-Head, Triton True 4-bit
"""
import torch
import json
import os
import sys
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache
from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton

# ── config ──────────────────────────────────────────
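# The model is selected by the first CLI argument and must have an entry in
# MODEL_PATHS; it defaults to mistral-7b.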
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
model_path  = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")

with open(f"{results_dir}/bit_allocation.json") as f:
    raw = json.load(f)
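# bit_allocation.json maps layer index -> head index -> bit width (all JSON keys
# are strings); convert to {layer (int): [bits per head, in head order]}.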
bit_alloc = {
    int(l): [raw[l][str(h)] for h in range(len(raw[l]))]
    for l in raw
}
num_layers = len(bit_alloc)
avg_bits   = sum(b for l in bit_alloc.values() for b in l) / \
             sum(len(l) for l in bit_alloc.values())

print(f"Model:    {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model     = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")


def measure_context(context_len: int):
    print(f"\n  Context {context_len} tokens...")
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()

    # peak memory
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
        kv  = out.past_key_values
    torch.cuda.synchronize()
    peak_mem = torch.cuda.max_memory_allocated() / 1e9

    # KV compression - all 4 methods
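    # (FP16 and uniform 8-bit sizes are computed analytically from element counts;
    # the naive and Triton caches report their own on-GPU footprints.)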
    fp16_bytes       = 0
    uniform8_bytes   = 0
    naive_real_bytes = 0
    triton_bytes     = 0

    for layer_idx in range(num_layers):
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values

        fp16_bytes     += k.numel() * 2 + v.numel() * 2
        uniform8_bytes += k.numel() + v.numel()

        cache_naive = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache_naive.store(k, v)
        naive_real_bytes += cache_naive.real_gpu_bytes()

        cache_triton = MixedPrecisionKVCacheTriton(bit_alloc[layer_idx])
        cache_triton.store(k, v)
        triton_bytes += cache_triton.memory_bytes()

    # prefill speed (3 runs average)
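    # (times the plain FP16 forward pass only; quantization cost is not included)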
    times = []
    for _ in range(3):
        torch.cuda.synchronize()
        t0 = time.time()
        with torch.no_grad():
            _ = model(input_ids, use_cache=True)
        torch.cuda.synchronize()
        times.append(time.time() - t0)
    prefill_ms = round(sum(times) / len(times) * 1000, 1)

    return {
        "context_len":          context_len,
        "peak_memory_gb":       round(peak_mem, 2),
        "fp16_mb":              round(fp16_bytes / 1e6, 2),
        "uniform8_mb":          round(uniform8_bytes / 1e6, 2),
        "naive_real_gpu_mb":    round(naive_real_bytes / 1e6, 2),
        "triton_mb":            round(triton_bytes / 1e6, 2),
        "naive_compression":    round(fp16_bytes / naive_real_bytes, 2),
        "triton_compression":   round(fp16_bytes / triton_bytes, 2),
        "prefill_ms":           prefill_ms,
    }


# ── RUN ──────────────────────────────────────────────
print("\n" + "="*75)
print("LONG CONTEXT BENCHMARK β€” 4 METHODS")
print("="*75)

results = []
for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
    try:
        r = measure_context(ctx)
        results.append(r)
        print(f"  ctx={ctx:6d} | "
              f"mem={r['peak_memory_gb']:.2f}GB | "
              f"FP16={r['fp16_mb']:.0f}MB | "
              f"8bit={r['uniform8_mb']:.0f}MB | "
              f"Naive={r['naive_real_gpu_mb']:.0f}MB({r['naive_compression']}x) | "
              f"Triton={r['triton_mb']:.0f}MB({r['triton_compression']}x) | "
              f"prefill={r['prefill_ms']}ms")
    except torch.cuda.OutOfMemoryError:
        print(f"  ctx={ctx:6d} | OOM at FP16 β€” compressed methods would fit βœ“")
        results.append({
            "context_len":    ctx,
            "peak_memory_gb": "OOM",
            "fp16_mb":        round(ctx * num_layers * 2 * 8 * 128 * 2 / 1e6, 2),
            "note":           "FP16 OOM"
        })
        break

# save
out_path = f"{results_dir}/long_context_results.json"
with open(out_path, "w") as f:
    json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
print(f"\nβœ… Saved to {out_path}")